diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index ed9b52d0499..e07194a5685 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -13,7 +13,11 @@ def get_available_models(): # TODO add a registration mechanism to torchvision.models - return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] + return [ + k + for k, v in models.__dict__.items() + if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight" + ] @pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50")) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 6c7234e2ef0..4dde47d40c3 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -4,10 +4,11 @@ import pytest import test_models as TM import torch +import torchvision from common_utils import cpu_and_gpu, needs_cuda +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._utils import handle_legacy_interface from torchvision.prototype import models -from torchvision.prototype.models._api import WeightsEnum, Weights -from torchvision.prototype.models._utils import handle_legacy_interface run_if_test_with_prototype = pytest.mark.skipif( os.getenv("PYTORCH_TEST_WITH_PROTOTYPE") != "1", @@ -54,27 +55,27 @@ def _build_model(fn, **kwargs): @pytest.mark.parametrize( "name, weight", [ - ("ResNet50_Weights.IMAGENET1K_V1", models.ResNet50_Weights.IMAGENET1K_V1), - ("ResNet50_Weights.DEFAULT", models.ResNet50_Weights.IMAGENET1K_V2), + ("ResNet50_Weights.IMAGENET1K_V1", torchvision.models.ResNet50_Weights.IMAGENET1K_V1), + ("ResNet50_Weights.DEFAULT", torchvision.models.ResNet50_Weights.IMAGENET1K_V2), ( "ResNet50_QuantizedWeights.DEFAULT", - models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2, + torchvision.models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2, ), ( "ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1", - models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1, + torchvision.models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1, ), ], ) def test_get_weight(name, weight): - assert models.get_weight(name) == weight + assert torchvision.models.get_weight(name) == weight @pytest.mark.parametrize( "model_fn", - TM.get_models_from_module(models) + TM.get_models_from_module(torchvision.models) + TM.get_models_from_module(models.detection) - + TM.get_models_from_module(models.quantization) + + TM.get_models_from_module(torchvision.models.quantization) + TM.get_models_from_module(models.segmentation) + TM.get_models_from_module(models.video) + TM.get_models_from_module(models.optical_flow), @@ -88,9 +89,9 @@ def test_naming_conventions(model_fn): @pytest.mark.parametrize( "model_fn", - TM.get_models_from_module(models) + TM.get_models_from_module(torchvision.models) + TM.get_models_from_module(models.detection) - + TM.get_models_from_module(models.quantization) + + TM.get_models_from_module(torchvision.models.quantization) + TM.get_models_from_module(models.segmentation) + TM.get_models_from_module(models.video) + TM.get_models_from_module(models.optical_flow), @@ -142,13 +143,6 @@ def test_schema_meta_validation(model_fn): assert not bad_names -@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models)) -@pytest.mark.parametrize("dev", cpu_and_gpu()) -@run_if_test_with_prototype -def test_classification_model(model_fn, dev): - TM.test_classification_model(model_fn, dev) - - @pytest.mark.parametrize("model_fn", 
TM.get_models_from_module(models.detection)) @pytest.mark.parametrize("dev", cpu_and_gpu()) @run_if_test_with_prototype @@ -156,12 +150,6 @@ def test_detection_model(model_fn, dev): TM.test_detection_model(model_fn, dev) -@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.quantization)) -@run_if_test_with_prototype -def test_quantized_classification_model(model_fn): - TM.test_quantized_classification_model(model_fn) - - @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.segmentation)) @pytest.mark.parametrize("dev", cpu_and_gpu()) @run_if_test_with_prototype @@ -186,9 +174,7 @@ def test_raft(model_builder, scripted): @pytest.mark.parametrize( "model_fn", - TM.get_models_from_module(models) - + TM.get_models_from_module(models.detection) - + TM.get_models_from_module(models.quantization) + TM.get_models_from_module(models.detection) + TM.get_models_from_module(models.segmentation) + TM.get_models_from_module(models.video) + TM.get_models_from_module(models.optical_flow), diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 16495e8552e..83e49908348 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -1,20 +1,20 @@ from .alexnet import * from .convnext import * -from .resnet import * -from .vgg import * -from .squeezenet import * -from .inception import * from .densenet import * +from .efficientnet import * from .googlenet import * -from .mobilenet import * +from .inception import * from .mnasnet import * -from .shufflenetv2 import * -from .efficientnet import * +from .mobilenet import * from .regnet import * +from .resnet import * +from .shufflenetv2 import * +from .squeezenet import * +from .vgg import * from .vision_transformer import * from . import detection -from . import feature_extraction from . import optical_flow from . import quantization from . import segmentation from . 
import video +from ._api import get_weight diff --git a/torchvision/prototype/models/_api.py b/torchvision/models/_api.py similarity index 98% rename from torchvision/prototype/models/_api.py rename to torchvision/models/_api.py index 85b280a7dfc..d841415a45a 100644 --- a/torchvision/prototype/models/_api.py +++ b/torchvision/models/_api.py @@ -7,7 +7,7 @@ from torchvision._utils import StrEnum -from ..._internally_replaced_utils import load_state_dict_from_url +from .._internally_replaced_utils import load_state_dict_from_url __all__ = ["WeightsEnum", "Weights", "get_weight"] diff --git a/torchvision/prototype/models/_meta.py b/torchvision/models/_meta.py similarity index 100% rename from torchvision/prototype/models/_meta.py rename to torchvision/models/_meta.py diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index f4e1cd84508..9e3a81411a1 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -1,8 +1,14 @@ +import functools +import inspect +import warnings from collections import OrderedDict -from typing import Dict, Optional +from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union from torch import nn +from .._utils import sequence_to_str +from ._api import WeightsEnum + class IntermediateLayerGetter(nn.ModuleDict): """ @@ -81,3 +87,158 @@ def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> if new_v < 0.9 * v: new_v += divisor return new_v + + +D = TypeVar("D") + + +def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]: + """Decorates a function that uses keyword only parameters to also allow them being passed as positionals. + + For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``: + + .. code:: + + def old_fn(foo, bar, baz=None): + ... + + def new_fn(foo, *, bar, baz=None): + ... + + Calling ``old_fn("foo", "bar, "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep BC + and at the same time warn the user of the deprecation, this decorator can be used: + + .. code:: + + @kwonly_to_pos_or_kw + def new_fn(foo, *, bar, baz=None): + ... + + new_fn("foo", "bar, "baz") + """ + params = inspect.signature(fn).parameters + + try: + keyword_only_start_idx = next( + idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY + ) + except StopIteration: + raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None + + keyword_only_params = tuple(inspect.signature(fn).parameters)[keyword_only_start_idx:] + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> D: + args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:] + if keyword_only_args: + keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args)) + warnings.warn( + f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional " + f"parameter(s) is deprecated. Please use keyword parameter(s) instead." + ) + kwargs.update(keyword_only_kwargs) + + return fn(*args, **kwargs) + + return wrapper + + +W = TypeVar("W", bound=WeightsEnum) +M = TypeVar("M", bound=nn.Module) +V = TypeVar("V") + + +def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): + """Decorates a model builder with the new interface to make it compatible with the old. + + In particular this handles two things: + + 1. 
Allows positional parameters again, but emits a deprecation warning in case they are used. See + :func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details. + 2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to + ``weights=Weights`` and emits a deprecation warning with instructions for the new interface. + + Args: + **weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter + name and default value for the legacy ``pretrained=True``. The default value can be a callable in which + case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be in + the dictionary is the deprecated parameter name passed as first element in the tuple. All other parameters + should be accessed with :meth:`~dict.get`. + """ + + def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]: + @kwonly_to_pos_or_kw + @functools.wraps(builder) + def inner_wrapper(*args: Any, **kwargs: Any) -> M: + for weights_param, (pretrained_param, default) in weights.items(): # type: ignore[union-attr] + # If neither the weights nor the pretrained parameter as passed, or the weights argument already use + # the new style arguments, there is nothing to do. Note that we cannot use `None` as sentinel for the + # weight argument, since it is a valid value. + sentinel = object() + weights_arg = kwargs.get(weights_param, sentinel) + if ( + (weights_param not in kwargs and pretrained_param not in kwargs) + or isinstance(weights_arg, WeightsEnum) + or (isinstance(weights_arg, str) and weights_arg != "legacy") + or weights_arg is None + ): + continue + + # If the pretrained parameter was passed as positional argument, it is now mapped to + # `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the current + # signature to infer the names of positionally passed arguments and thus has no knowledge that there + # used to be a pretrained parameter. + pretrained_positional = weights_arg is not sentinel + if pretrained_positional: + # We put the pretrained argument under its legacy name in the keyword argument dictionary to have a + # unified access to the value if the default value is a callable. + kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param) + else: + pretrained_arg = kwargs[pretrained_param] + + if pretrained_arg: + default_weights_arg = default(kwargs) if callable(default) else default + if not isinstance(default_weights_arg, WeightsEnum): + raise ValueError(f"No weights available for model {builder.__name__}") + else: + default_weights_arg = None + + if not pretrained_positional: + warnings.warn( + f"The parameter '{pretrained_param}' is deprecated, please use '{weights_param}' instead." + ) + + msg = ( + f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated. " + f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`." + ) + if pretrained_arg: + msg = ( + f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.DEFAULT` " + f"to get the most up-to-date weights." 
+ ) + warnings.warn(msg) + + del kwargs[pretrained_param] + kwargs[weights_param] = default_weights_arg + + return builder(*args, **kwargs) + + return inner_wrapper + + return outer_wrapper + + +def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None: + if param in kwargs: + if kwargs[param] != new_value: + raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.") + else: + kwargs[param] = new_value + + +def _ovewrite_value_param(param: Optional[V], new_value: V) -> V: + if param is not None: + if param != new_value: + raise ValueError(f"The parameter '{param}' expected value {new_value} but got {param} instead.") + return new_value diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index bb812febdc4..4df533000f9 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -1,18 +1,17 @@ -from typing import Any +from functools import partial +from typing import Any, Optional import torch import torch.nn as nn -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["AlexNet", "alexnet"] - - -model_urls = { - "alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", -} +__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"] class AlexNet(nn.Module): @@ -53,17 +52,45 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> AlexNet: +class AlexNet_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + "task": "image_classification", + "architecture": "AlexNet", + "publication_year": 2012, + "num_params": 61100840, + "size": (224, 224), + "min_size": (63, 63), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg", + "acc@1": 56.522, + "acc@5": 79.066, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1)) +def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: r"""AlexNet model architecture from the `"One weird trick..." `_ paper. The required minimum input size of the model is 63x63. 
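The hunk above converts `alexnet` to the new interface: the `pretrained` flag is removed from the signature, an `AlexNet_Weights` enum carries the checkpoint URL plus metadata, and `handle_legacy_interface` keeps old call sites working. A minimal usage sketch of that interface, assuming torchvision at the state of this patch (the checkpoint is downloaded on first use):

```python
import torchvision
from torchvision.models import alexnet, AlexNet_Weights

# New-style call: pass a weights enum, or None for randomly initialized weights.
model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)

# Weights can also be resolved from their full enum name, mirroring
# the test_get_weight cases above for ResNet50.
weights = torchvision.models.get_weight("AlexNet_Weights.IMAGENET1K_V1")
assert weights == AlexNet_Weights.IMAGENET1K_V1

# Legacy call: handle_legacy_interface maps pretrained=True to
# weights=AlexNet_Weights.IMAGENET1K_V1 and emits a deprecation warning.
model = alexnet(pretrained=True)
```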
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (AlexNet_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = AlexNet_Weights.verify(weights) + + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = AlexNet(**kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls["alexnet"], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index 3a0dcdb31cd..8d25e77eaa1 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -1,18 +1,25 @@ from functools import partial -from typing import Any, Callable, Dict, List, Optional, Sequence +from typing import Any, Callable, List, Optional, Sequence import torch from torch import nn, Tensor from torch.nn import functional as F -from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import Conv2dNormActivation from ..ops.stochastic_depth import StochasticDepth +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = [ "ConvNeXt", + "ConvNeXt_Tiny_Weights", + "ConvNeXt_Small_Weights", + "ConvNeXt_Base_Weights", + "ConvNeXt_Large_Weights", "convnext_tiny", "convnext_small", "convnext_base", @@ -20,14 +27,6 @@ ] -_MODELS_URLS: Dict[str, Optional[str]] = { - "convnext_tiny": "https://download.pytorch.org/models/convnext_tiny-983f1562.pth", - "convnext_small": "https://download.pytorch.org/models/convnext_small-0c510722.pth", - "convnext_base": "https://download.pytorch.org/models/convnext_base-6075fbad.pth", - "convnext_large": "https://download.pytorch.org/models/convnext_large-ea097f82.pth", -} - - class LayerNorm2d(nn.LayerNorm): def forward(self, x: Tensor) -> Tensor: x = x.permute(0, 2, 3, 1) @@ -187,29 +186,101 @@ def forward(self, x: Tensor) -> Tensor: def _convnext( - arch: str, block_setting: List[CNBlockConfig], stochastic_depth_prob: float, - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> ConvNeXt: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs) - if pretrained: - if arch not in _MODELS_URLS: - raise ValueError(f"No checkpoint is available for model type {arch}") - state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model -def convnext_tiny(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: +_COMMON_META = { + "task": "image_classification", + "architecture": "ConvNeXt", + "publication_year": 2022, + "size": (224, 224), + "min_size": (32, 32), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext", +} + + +class ConvNeXt_Tiny_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + 
url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=236), + meta={ + **_COMMON_META, + "num_params": 28589128, + "acc@1": 82.520, + "acc@5": 96.146, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ConvNeXt_Small_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/convnext_small-0c510722.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=230), + meta={ + **_COMMON_META, + "num_params": 50223688, + "acc@1": 83.616, + "acc@5": 96.650, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ConvNeXt_Base_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/convnext_base-6075fbad.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 88591464, + "acc@1": 84.062, + "acc@5": 96.870, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ConvNeXt_Large_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/convnext_large-ea097f82.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 197767336, + "acc@1": 84.414, + "acc@5": 96.976, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1)) +def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: r"""ConvNeXt Tiny model architecture from the `"A ConvNet for the 2020s" `_ paper. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ConvNeXt_Tiny_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ConvNeXt_Tiny_Weights.verify(weights) + block_setting = [ CNBlockConfig(96, 192, 3), CNBlockConfig(192, 384, 3), @@ -217,16 +288,21 @@ def convnext_tiny(*, pretrained: bool = False, progress: bool = True, **kwargs: CNBlockConfig(768, None, 3), ] stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1) - return _convnext("convnext_tiny", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) + return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) -def convnext_small(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: +@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1)) +def convnext_small( + *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any +) -> ConvNeXt: r"""ConvNeXt Small model architecture from the `"A ConvNet for the 2020s" `_ paper. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ConvNeXt_Small_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ConvNeXt_Small_Weights.verify(weights) + block_setting = [ CNBlockConfig(96, 192, 3), CNBlockConfig(192, 384, 3), @@ -234,16 +310,19 @@ def convnext_small(*, pretrained: bool = False, progress: bool = True, **kwargs: CNBlockConfig(768, None, 3), ] stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4) - return _convnext("convnext_small", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) + return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) -def convnext_base(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: +@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1)) +def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: r"""ConvNeXt Base model architecture from the `"A ConvNet for the 2020s" `_ paper. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ConvNeXt_Base_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ConvNeXt_Base_Weights.verify(weights) + block_setting = [ CNBlockConfig(128, 256, 3), CNBlockConfig(256, 512, 3), @@ -251,16 +330,21 @@ def convnext_base(*, pretrained: bool = False, progress: bool = True, **kwargs: CNBlockConfig(1024, None, 3), ] stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5) - return _convnext("convnext_base", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) + return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) -def convnext_large(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: +@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1)) +def convnext_large( + *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any +) -> ConvNeXt: r"""ConvNeXt Large model architecture from the `"A ConvNet for the 2020s" `_ paper. 
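When weights are passed, the builders above force `num_classes` to match the checkpoint through `_ovewrite_named_param`, so an explicit value is only accepted when it agrees with the metadata; a conflicting value raises a `ValueError`. For a different head size (the 10 classes below are a hypothetical fine-tuning target) the usual pattern is to load the pretrained trunk and replace the classifier afterwards:

```python
from torch import nn
from torchvision.models import convnext_base, ConvNeXt_Base_Weights

# Accepted: consistent with the checkpoint's 1000 ImageNet categories.
model = convnext_base(weights=ConvNeXt_Base_Weights.IMAGENET1K_V1, num_classes=1000)

# Raises ValueError via _ovewrite_named_param:
# convnext_base(weights=ConvNeXt_Base_Weights.IMAGENET1K_V1, num_classes=10)

# Fine-tuning sketch: keep the pretrained features, swap the final linear layer.
model = convnext_base(weights=ConvNeXt_Base_Weights.IMAGENET1K_V1)
model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, 10)
```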
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ConvNeXt_Large_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ConvNeXt_Large_Weights.verify(weights) + block_setting = [ CNBlockConfig(192, 384, 3), CNBlockConfig(384, 768, 3), @@ -268,4 +352,4 @@ def convnext_large(*, pretrained: bool = False, progress: bool = True, **kwargs: CNBlockConfig(1536, None, 3), ] stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5) - return _convnext("convnext_large", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) + return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index 14e318360af..b0de4529902 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -1,6 +1,7 @@ import re from collections import OrderedDict -from typing import Any, List, Tuple +from functools import partial +from typing import Any, List, Optional, Tuple import torch import torch.nn as nn @@ -8,18 +9,24 @@ import torch.utils.checkpoint as cp from torch import Tensor -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["DenseNet", "densenet121", "densenet169", "densenet201", "densenet161"] - -model_urls = { - "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth", - "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth", - "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth", - "densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth", -} +__all__ = [ + "DenseNet", + "DenseNet121_Weights", + "DenseNet161_Weights", + "DenseNet169_Weights", + "DenseNet201_Weights", + "densenet121", + "densenet161", + "densenet169", + "densenet201", +] class _DenseLayer(nn.Module): @@ -220,7 +227,7 @@ def forward(self, x: Tensor) -> Tensor: return out -def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None: +def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None: # '.'s are no longer allowed in module names, but previous _DenseLayer # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. # They are also in the checkpoints in model_urls. 
This pattern is used @@ -229,7 +236,7 @@ def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None: r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" ) - state_dict = load_state_dict_from_url(model_url, progress=progress) + state_dict = weights.get_state_dict(progress=progress) for key in list(state_dict.keys()): res = pattern.match(key) if res: @@ -240,71 +247,155 @@ def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None: def _densenet( - arch: str, growth_rate: int, block_config: Tuple[int, int, int, int], num_init_features: int, - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> DenseNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) - if pretrained: - _load_state_dict(model, model_urls[arch], progress) + + if weights is not None: + _load_state_dict(model=model, weights=weights, progress=progress) + return model -def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: +_COMMON_META = { + "task": "image_classification", + "architecture": "DenseNet", + "publication_year": 2016, + "size": (224, 224), + "min_size": (29, 29), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/pull/116", +} + + +class DenseNet121_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/densenet121-a639ec97.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 7978856, + "acc@1": 74.434, + "acc@5": 91.972, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class DenseNet161_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/densenet161-8d451a50.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 28681000, + "acc@1": 77.138, + "acc@5": 93.560, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class DenseNet169_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/densenet169-b2777c0a.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 14149480, + "acc@1": 75.600, + "acc@5": 92.806, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class DenseNet201_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/densenet201-c1103571.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 20013928, + "acc@1": 76.896, + "acc@5": 93.370, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1)) +def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: r"""Densenet-121 model from `"Densely Connected Convolutional Networks" `_. The required minimum input size of the model is 29x29. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (DenseNet121_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. 
""" - return _densenet("densenet121", 32, (6, 12, 24, 16), 64, pretrained, progress, **kwargs) + weights = DenseNet121_Weights.verify(weights) + return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs) -def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: + +@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1)) +def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: r"""Densenet-161 model from `"Densely Connected Convolutional Networks" `_. The required minimum input size of the model is 29x29. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (DenseNet161_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. """ - return _densenet("densenet161", 48, (6, 12, 36, 24), 96, pretrained, progress, **kwargs) + weights = DenseNet161_Weights.verify(weights) + + return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs) -def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: +@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1)) +def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: r"""Densenet-169 model from `"Densely Connected Convolutional Networks" `_. The required minimum input size of the model is 29x29. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (DenseNet169_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. """ - return _densenet("densenet169", 32, (6, 12, 32, 32), 64, pretrained, progress, **kwargs) + weights = DenseNet169_Weights.verify(weights) + return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs) -def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: + +@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1)) +def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: r"""Densenet-201 model from `"Densely Connected Convolutional Networks" `_. The required minimum input size of the model is 29x29. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (DenseNet201_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. 
""" - return _densenet("densenet201", 32, (6, 12, 48, 32), 64, pretrained, progress, **kwargs) + weights = DenseNet201_Weights.verify(weights) + + return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs) diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index f8238912ffd..9665c169bbf 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -9,14 +9,27 @@ from torch import nn, Tensor from torchvision.ops import StochasticDepth -from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import Conv2dNormActivation, SqueezeExcitation +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once -from ._utils import _make_divisible +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible __all__ = [ "EfficientNet", + "EfficientNet_B0_Weights", + "EfficientNet_B1_Weights", + "EfficientNet_B2_Weights", + "EfficientNet_B3_Weights", + "EfficientNet_B4_Weights", + "EfficientNet_B5_Weights", + "EfficientNet_B6_Weights", + "EfficientNet_B7_Weights", + "EfficientNet_V2_S_Weights", + "EfficientNet_V2_M_Weights", + "EfficientNet_V2_L_Weights", "efficientnet_b0", "efficientnet_b1", "efficientnet_b2", @@ -31,25 +44,6 @@ ] -model_urls = { - # Weights ported from https://github.com/rwightman/pytorch-image-models/ - "efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", - "efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", - "efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", - "efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", - "efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", - # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/ - "efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", - "efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", - "efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", - # Weights trained with TorchVision - "efficientnet_v2_s": "https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth", - "efficientnet_v2_m": "https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth", - # Weights ported from TF - "efficientnet_v2_l": "https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth", -} - - @dataclass class _MBConvConfig: expand_ratio: float @@ -362,20 +356,21 @@ def forward(self, x: Tensor) -> Tensor: def _efficientnet( - arch: str, inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]], dropout: float, last_channel: Optional[int], - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> EfficientNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs) - if pretrained: - if model_urls.get(arch, None) is None: - raise ValueError(f"No checkpoint is available for model type {arch}") - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + + 
if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model @@ -434,208 +429,484 @@ def _efficientnet_conf( return inverted_residual_setting, last_channel -def efficientnet_b0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +_COMMON_META = { + "task": "image_classification", + "categories": _IMAGENET_CATEGORIES, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet", +} + + +_COMMON_META_V1 = { + **_COMMON_META, + "architecture": "EfficientNet", + "publication_year": 2019, + "interpolation": InterpolationMode.BICUBIC, + "min_size": (1, 1), +} + + +_COMMON_META_V2 = { + **_COMMON_META, + "architecture": "EfficientNetV2", + "publication_year": 2021, + "interpolation": InterpolationMode.BILINEAR, + "min_size": (33, 33), +} + + +class EfficientNet_B0_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", + transforms=partial( + ImageClassificationEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 5288548, + "size": (224, 224), + "acc@1": 77.692, + "acc@5": 93.532, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B1_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", + transforms=partial( + ImageClassificationEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 7794184, + "size": (240, 240), + "acc@1": 78.642, + "acc@5": 94.186, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth", + transforms=partial( + ImageClassificationEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR + ), + meta={ + **_COMMON_META_V1, + "num_params": 7794184, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning", + "interpolation": InterpolationMode.BILINEAR, + "size": (240, 240), + "acc@1": 79.838, + "acc@5": 94.934, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class EfficientNet_B2_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", + transforms=partial( + ImageClassificationEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 9109994, + "size": (288, 288), + "acc@1": 80.608, + "acc@5": 95.310, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B3_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", + transforms=partial( + ImageClassificationEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 12233232, + "size": (300, 300), + "acc@1": 82.008, + "acc@5": 96.054, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B4_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", + transforms=partial( + ImageClassificationEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 19341616, + "size": (380, 380), + "acc@1": 83.384, + "acc@5": 96.594, + }, + ) + DEFAULT = 
IMAGENET1K_V1 + + +class EfficientNet_B5_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", + transforms=partial( + ImageClassificationEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 30389784, + "size": (456, 456), + "acc@1": 83.444, + "acc@5": 96.628, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B6_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", + transforms=partial( + ImageClassificationEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 43040704, + "size": (528, 528), + "acc@1": 84.008, + "acc@5": 96.916, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B7_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", + transforms=partial( + ImageClassificationEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 66347960, + "size": (600, 600), + "acc@1": 84.122, + "acc@5": 96.908, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_V2_S_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth", + transforms=partial( + ImageClassificationEval, + crop_size=384, + resize_size=384, + interpolation=InterpolationMode.BILINEAR, + ), + meta={ + **_COMMON_META_V2, + "num_params": 21458488, + "size": (384, 384), + "acc@1": 84.228, + "acc@5": 96.878, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_V2_M_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth", + transforms=partial( + ImageClassificationEval, + crop_size=480, + resize_size=480, + interpolation=InterpolationMode.BILINEAR, + ), + meta={ + **_COMMON_META_V2, + "num_params": 54139356, + "size": (480, 480), + "acc@1": 85.112, + "acc@5": 97.156, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_V2_L_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth", + transforms=partial( + ImageClassificationEval, + crop_size=480, + resize_size=480, + interpolation=InterpolationMode.BICUBIC, + mean=(0.5, 0.5, 0.5), + std=(0.5, 0.5, 0.5), + ), + meta={ + **_COMMON_META_V2, + "num_params": 118515272, + "size": (480, 480), + "acc@1": 85.808, + "acc@5": 97.788, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1)) +def efficientnet_b0( + *, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B0 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. 
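The `meta` dictionaries on the enums above expose the reported accuracies, parameter counts and input sizes programmatically, which is also what the schema-validation test in `test_prototype_models.py` checks. A small sketch comparing two of the checkpoints defined here:

```python
from torchvision.models import EfficientNet_B0_Weights, EfficientNet_B4_Weights

for w in (EfficientNet_B0_Weights.IMAGENET1K_V1, EfficientNet_B4_Weights.IMAGENET1K_V1):
    meta = w.meta
    print(
        f"{w}: {meta['num_params']:,} params, size={meta['size']}, "
        f"acc@1={meta['acc@1']}, acc@5={meta['acc@5']}"
    )
```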
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B0_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b0" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.0, depth_mult=1.0) - return _efficientnet(arch, inverted_residual_setting, 0.2, last_channel, pretrained, progress, **kwargs) + weights = EfficientNet_B0_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0) + return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs) -def efficientnet_b1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1)) +def efficientnet_b1( + *, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B1 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B1_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b1" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.0, depth_mult=1.1) - return _efficientnet(arch, inverted_residual_setting, 0.2, last_channel, pretrained, progress, **kwargs) + weights = EfficientNet_B1_Weights.verify(weights) + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1) + return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs) -def efficientnet_b2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: + +@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1)) +def efficientnet_b2( + *, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B2 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B2_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b2" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.1, depth_mult=1.2) - return _efficientnet(arch, inverted_residual_setting, 0.3, last_channel, pretrained, progress, **kwargs) + weights = EfficientNet_B2_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2) + return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs) -def efficientnet_b3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1)) +def efficientnet_b3( + *, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B3 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B3_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b3" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.2, depth_mult=1.4) - return _efficientnet(arch, inverted_residual_setting, 0.3, last_channel, pretrained, progress, **kwargs) + weights = EfficientNet_B3_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4) + return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs) -def efficientnet_b4(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1)) +def efficientnet_b4( + *, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B4 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B4_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b4" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.4, depth_mult=1.8) - return _efficientnet(arch, inverted_residual_setting, 0.4, last_channel, pretrained, progress, **kwargs) + weights = EfficientNet_B4_Weights.verify(weights) + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8) + return _efficientnet(inverted_residual_setting, 0.4, last_channel, weights, progress, **kwargs) -def efficientnet_b5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: + +@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1)) +def efficientnet_b5( + *, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B5 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B5_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b5" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.6, depth_mult=2.2) + weights = EfficientNet_B5_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2) return _efficientnet( - arch, inverted_residual_setting, 0.4, last_channel, - pretrained, + weights, progress, norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs, ) -def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1)) +def efficientnet_b6( + *, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B6 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B6_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b6" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.8, depth_mult=2.6) + weights = EfficientNet_B6_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6) return _efficientnet( - arch, inverted_residual_setting, 0.5, last_channel, - pretrained, + weights, progress, norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs, ) -def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1)) +def efficientnet_b7( + *, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B7 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. 
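Where a model has more than one checkpoint, the `DEFAULT` alias on its enum points at the preferred one; for `EfficientNet_B1_Weights` above that is the newer `IMAGENET1K_V2` recipe, while single-checkpoint enums alias their only entry. The string form `"...Weights.DEFAULT"` is likewise accepted by `get_weight`, as `test_get_weight` above exercises for ResNet50. A quick sketch:

```python
from torchvision.models import EfficientNet_B1_Weights, EfficientNet_B7_Weights

assert EfficientNet_B1_Weights.DEFAULT is EfficientNet_B1_Weights.IMAGENET1K_V2
assert EfficientNet_B7_Weights.DEFAULT is EfficientNet_B7_Weights.IMAGENET1K_V1
```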
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B7_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b7" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=2.0, depth_mult=3.1) + weights = EfficientNet_B7_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1) return _efficientnet( - arch, inverted_residual_setting, 0.5, last_channel, - pretrained, + weights, progress, norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs, ) -def efficientnet_v2_s(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1)) +def efficientnet_v2_s( + *, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs an EfficientNetV2-S architecture from `"EfficientNetV2: Smaller Models and Faster Training" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_V2_S_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_v2_s" - inverted_residual_setting, last_channel = _efficientnet_conf(arch) + weights = EfficientNet_V2_S_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s") return _efficientnet( - arch, inverted_residual_setting, 0.2, last_channel, - pretrained, + weights, progress, norm_layer=partial(nn.BatchNorm2d, eps=1e-03), **kwargs, ) -def efficientnet_v2_m(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1)) +def efficientnet_v2_m( + *, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs an EfficientNetV2-M architecture from `"EfficientNetV2: Smaller Models and Faster Training" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_V2_M_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_v2_m" - inverted_residual_setting, last_channel = _efficientnet_conf(arch) + weights = EfficientNet_V2_M_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_m") return _efficientnet( - arch, inverted_residual_setting, 0.3, last_channel, - pretrained, + weights, progress, norm_layer=partial(nn.BatchNorm2d, eps=1e-03), **kwargs, ) -def efficientnet_v2_l(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1)) +def efficientnet_v2_l( + *, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs an EfficientNetV2-L architecture from `"EfficientNetV2: Smaller Models and Faster Training" `_. 
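The `handle_legacy_interface` decorator applied to every builder above is what keeps the removed `pretrained` flag working: `pretrained=True` is rewritten into the default weights named in the decorator, `pretrained=False` into `weights=None`, and both emit deprecation warnings pointing at the new parameter. A sketch of the observable behavior (the first call downloads the checkpoint):

```python
import warnings
from torchvision.models import efficientnet_v2_s

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Old-style call, rewritten to weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1.
    model = efficientnet_v2_s(pretrained=True)

assert any("deprecated" in str(w.message) for w in caught)

# pretrained=False means "no weights"; it also warns about the renamed parameter.
model = efficientnet_v2_s(pretrained=False)
```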
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_V2_L_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_v2_l" - inverted_residual_setting, last_channel = _efficientnet_conf(arch) + weights = EfficientNet_V2_L_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l") return _efficientnet( - arch, inverted_residual_setting, 0.4, last_channel, - pretrained, + weights, progress, norm_layer=partial(nn.BatchNorm2d, eps=1e-03), **kwargs, diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index 9e4c3498aab..2cac4a4fbbd 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -1,5 +1,6 @@ import warnings from collections import namedtuple +from functools import partial from typing import Optional, Tuple, List, Callable, Any import torch @@ -7,15 +8,15 @@ import torch.nn.functional as F from torch import Tensor -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["GoogLeNet", "googlenet", "GoogLeNetOutputs", "_GoogLeNetOutputs"] -model_urls = { - # GoogLeNet ported from TensorFlow - "googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth", -} +__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"] + GoogLeNetOutputs = namedtuple("GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"]) GoogLeNetOutputs.__annotations__ = {"logits": Tensor, "aux_logits2": Optional[Tensor], "aux_logits1": Optional[Tensor]} @@ -273,38 +274,62 @@ def forward(self, x: Tensor) -> Tensor: return F.relu(x, inplace=True) -def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> GoogLeNet: +class GoogLeNet_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/googlenet-1378be20.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + "task": "image_classification", + "architecture": "GoogLeNet", + "publication_year": 2014, + "num_params": 6624904, + "size": (224, 224), + "min_size": (15, 15), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#googlenet", + "acc@1": 69.778, + "acc@5": 89.530, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.IMAGENET1K_V1)) +def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: r"""GoogLeNet (Inception v1) model architecture from `"Going Deeper with Convolutions" `_. The required minimum input size of the model is 15x15. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (GoogLeNet_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr aux_logits (bool): If True, adds two auxiliary branches that can improve training. 
Default: *False* when pretrained is True otherwise *True* transform_input (bool): If True, preprocesses the input according to the method with which it was trained on ImageNet. Default: True if ``pretrained=True``, else False. """ - if pretrained: + weights = GoogLeNet_Weights.verify(weights) + + original_aux_logits = kwargs.get("aux_logits", False) + if weights is not None: if "transform_input" not in kwargs: - kwargs["transform_input"] = True - if "aux_logits" not in kwargs: - kwargs["aux_logits"] = False - if kwargs["aux_logits"]: - warnings.warn( - "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" - ) - original_aux_logits = kwargs["aux_logits"] - kwargs["aux_logits"] = True - kwargs["init_weights"] = False - model = GoogLeNet(**kwargs) - state_dict = load_state_dict_from_url(model_urls["googlenet"], progress=progress) - model.load_state_dict(state_dict) + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "init_weights", False) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = GoogLeNet(**kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) if not original_aux_logits: model.aux_logits = False model.aux1 = None # type: ignore[assignment] model.aux2 = None # type: ignore[assignment] - return model + else: + warnings.warn( + "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" + ) - return GoogLeNet(**kwargs) + return model diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index c489925cb45..1628542482b 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -1,23 +1,22 @@ import warnings from collections import namedtuple +from functools import partial from typing import Callable, Any, Optional, Tuple, List import torch import torch.nn.functional as F from torch import nn, Tensor -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["Inception3", "inception_v3", "InceptionOutputs", "_InceptionOutputs"] +__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"] -model_urls = { - # Inception v3 ported from TensorFlow - "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", -} - InceptionOutputs = namedtuple("InceptionOutputs", ["logits", "aux_logits"]) InceptionOutputs.__annotations__ = {"logits": Tensor, "aux_logits": Optional[Tensor]} @@ -407,7 +406,29 @@ def forward(self, x: Tensor) -> Tensor: return F.relu(x, inplace=True) -def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> Inception3: +class Inception_V3_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", + transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342), + meta={ + "task": "image_classification", + "architecture": "InceptionV3", + "publication_year": 2015, + "num_params": 27161264, + "size": (299, 299), + "min_size": (75, 75), + "categories": _IMAGENET_CATEGORIES, + "interpolation": 
InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#inception-v3", + "acc@1": 77.294, + "acc@5": 93.450, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.IMAGENET1K_V1)) +def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: r"""Inception v3 model architecture from `"Rethinking the Inception Architecture for Computer Vision" `_. The required minimum input size of the model is 75x75. @@ -417,28 +438,29 @@ def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) N x 3 x 299 x 299, so ensure your images are sized accordingly. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (Inception_V3_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr aux_logits (bool): If True, add an auxiliary branch that can improve training. Default: *True* transform_input (bool): If True, preprocesses the input according to the method with which it was trained on ImageNet. Default: True if ``pretrained=True``, else False. """ - if pretrained: + weights = Inception_V3_Weights.verify(weights) + + original_aux_logits = kwargs.get("aux_logits", True) + if weights is not None: if "transform_input" not in kwargs: - kwargs["transform_input"] = True - if "aux_logits" in kwargs: - original_aux_logits = kwargs["aux_logits"] - kwargs["aux_logits"] = True - else: - original_aux_logits = True - kwargs["init_weights"] = False # we are loading weights from a pretrained model - model = Inception3(**kwargs) - state_dict = load_state_dict_from_url(model_urls["inception_v3_google"], progress=progress) - model.load_state_dict(state_dict) + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "init_weights", False) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = Inception3(**kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) if not original_aux_logits: model.aux_logits = False model.AuxLogits = None - return model - return Inception3(**kwargs) + return model diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index c3d4013f30c..b6ac15c68d2 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -1,21 +1,30 @@ import warnings -from typing import Any, Dict, List +from functools import partial +from typing import Any, Dict, List, Optional import torch import torch.nn as nn from torch import Tensor -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["MNASNet", "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3"] -_MODEL_URLS = { - "mnasnet0_5": "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", - "mnasnet0_75": None, - "mnasnet1_0": "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", - "mnasnet1_3": None, -} +__all__ = [ + "MNASNet", + "MNASNet0_5_Weights", + "MNASNet0_75_Weights", + "MNASNet1_0_Weights", + "MNASNet1_3_Weights", + 
"mnasnet0_5", + "mnasnet0_75", + "mnasnet1_0", + "mnasnet1_3", +] + # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is # 1.0 - tensorflow. @@ -196,68 +205,123 @@ def _load_from_state_dict( ) -def _load_pretrained(model_name: str, model: nn.Module, progress: bool) -> None: - if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None: - raise ValueError(f"No checkpoint is available for model type {model_name}") - checkpoint_url = _MODEL_URLS[model_name] - model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress)) +_COMMON_META = { + "task": "image_classification", + "architecture": "MNASNet", + "publication_year": 2018, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/1e100/mnasnet_trainer", +} + + +class MNASNet0_5_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 2218512, + "acc@1": 67.734, + "acc@5": 87.490, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class MNASNet0_75_Weights(WeightsEnum): + # If a default model is added here the corresponding changes need to be done in mnasnet0_75 + pass + + +class MNASNet1_0_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 4383312, + "acc@1": 73.456, + "acc@5": 91.510, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class MNASNet1_3_Weights(WeightsEnum): + # If a default model is added here the corresponding changes need to be done in mnasnet1_3 + pass + + +def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = MNASNet(alpha, **kwargs) -def mnasnet0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet: + if weights: + model.load_state_dict(weights.get_state_dict(progress=progress)) + + return model + + +@handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.IMAGENET1K_V1)) +def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: r"""MNASNet with depth multiplier of 0.5 from `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MNASNet0_5_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - model = MNASNet(0.5, **kwargs) - if pretrained: - _load_pretrained("mnasnet0_5", model, progress) - return model + weights = MNASNet0_5_Weights.verify(weights) + + return _mnasnet(0.5, weights, progress, **kwargs) -def mnasnet0_75(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet: +@handle_legacy_interface(weights=("pretrained", None)) +def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: r"""MNASNet with depth multiplier of 0.75 from `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MNASNet0_75_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - model = MNASNet(0.75, **kwargs) - if pretrained: - _load_pretrained("mnasnet0_75", model, progress) - return model + weights = MNASNet0_75_Weights.verify(weights) + + return _mnasnet(0.75, weights, progress, **kwargs) -def mnasnet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet: +@handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.IMAGENET1K_V1)) +def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: r"""MNASNet with depth multiplier of 1.0 from `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MNASNet1_0_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - model = MNASNet(1.0, **kwargs) - if pretrained: - _load_pretrained("mnasnet1_0", model, progress) - return model + weights = MNASNet1_0_Weights.verify(weights) + return _mnasnet(1.0, weights, progress, **kwargs) -def mnasnet1_3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet: + +@handle_legacy_interface(weights=("pretrained", None)) +def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: r"""MNASNet with depth multiplier of 1.3 from `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MNASNet1_3_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - model = MNASNet(1.3, **kwargs) - if pretrained: - _load_pretrained("mnasnet1_3", model, progress) - return model + weights = MNASNet1_3_Weights.verify(weights) + + return _mnasnet(1.3, weights, progress, **kwargs) diff --git a/torchvision/models/mobilenet.py b/torchvision/models/mobilenet.py index 4108305d3f5..0a270d14d3a 100644 --- a/torchvision/models/mobilenet.py +++ b/torchvision/models/mobilenet.py @@ -1,4 +1,6 @@ -from .mobilenetv2 import MobileNetV2, mobilenet_v2, __all__ as mv2_all -from .mobilenetv3 import MobileNetV3, mobilenet_v3_large, mobilenet_v3_small, __all__ as mv3_all +from .mobilenetv2 import * # noqa: F401, F403 +from .mobilenetv3 import * # noqa: F401, F403 +from .mobilenetv2 import __all__ as mv2_all +from .mobilenetv3 import __all__ as mv3_all __all__ = mv2_all + mv3_all diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index 930f68d13e9..acd94af4d10 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -1,22 +1,20 @@ import warnings +from functools import partial from typing import Callable, Any, Optional, List import torch from torch import Tensor from torch import nn -from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import Conv2dNormActivation +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once -from ._utils import _make_divisible +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible -__all__ = 
["MobileNetV2", "mobilenet_v2"] - - -model_urls = { - "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", -} +__all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"] # necessary for backwards compatibility @@ -195,17 +193,62 @@ def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) -def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2: +_COMMON_META = { + "task": "image_classification", + "architecture": "MobileNetV2", + "publication_year": 2018, + "num_params": 3504872, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class MobileNet_V2_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2", + "acc@1": 71.878, + "acc@5": 90.286, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", + "acc@1": 72.154, + "acc@5": 90.822, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +@handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1)) +def mobilenet_v2( + *, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any +) -> MobileNetV2: """ Constructs a MobileNetV2 architecture from `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MobileNet_V2_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = MobileNet_V2_Weights.verify(weights) + + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = MobileNetV2(**kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls["mobilenet_v2"], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 530467d6d53..3a98456416d 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -5,19 +5,21 @@ import torch from torch import nn, Tensor -from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once -from ._utils import _make_divisible +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible -__all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"] - - -model_urls = { - "mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", - "mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", -} +__all__ = [ + "MobileNetV3", + "MobileNet_V3_Large_Weights", + "MobileNet_V3_Small_Weights", + "mobilenet_v3_large", + "mobilenet_v3_small", +] class SqueezeExcitation(SElayer): @@ -284,45 +286,106 @@ def _mobilenet_v3_conf( def _mobilenet_v3( - arch: str, inverted_residual_setting: List[InvertedResidualConfig], last_channel: int, - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, -): +) -> MobileNetV3: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) - if pretrained: - if model_urls.get(arch, None) is None: - raise ValueError(f"No checkpoint is available for model type {arch}") - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model -def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: +_COMMON_META = { + "task": "image_classification", + "architecture": "MobileNetV3", + "publication_year": 2019, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class MobileNet_V3_Large_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 5483032, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small", + "acc@1": 74.042, + "acc@5": 91.340, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth", + 
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 5483032, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", + "acc@1": 75.274, + "acc@5": 92.566, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class MobileNet_V3_Small_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 2542856, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small", + "acc@1": 67.668, + "acc@5": 87.402, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.IMAGENET1K_V1)) +def mobilenet_v3_large( + *, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any +) -> MobileNetV3: """ Constructs a large MobileNetV3 architecture from `"Searching for MobileNetV3" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MobileNet_V3_Large_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "mobilenet_v3_large" - inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs) - return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) + weights = MobileNet_V3_Large_Weights.verify(weights) + + inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) + return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) -def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: +@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1)) +def mobilenet_v3_small( + *, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any +) -> MobileNetV3: """ Constructs a small MobileNetV3 architecture from `"Searching for MobileNetV3" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MobileNet_V3_Small_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "mobilenet_v3_small" - inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs) - return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) + weights = MobileNet_V3_Small_Weights.verify(weights) + + inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs) + return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) diff --git a/torchvision/models/quantization/__init__.py b/torchvision/models/quantization/__init__.py index deae997a219..da8bbba3567 100644 --- a/torchvision/models/quantization/__init__.py +++ b/torchvision/models/quantization/__init__.py @@ -1,5 +1,5 @@ -from .mobilenet import * -from .resnet import * from .googlenet import * from .inception import * +from .mobilenet import * +from .resnet import * from .shufflenetv2 import * diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py index 98d9382214f..befc2299c06 100644 --- a/torchvision/models/quantization/googlenet.py +++ b/torchvision/models/quantization/googlenet.py @@ -1,22 +1,25 @@ import warnings -from typing import Any, Optional +from functools import partial +from typing import Any, Optional, Union import torch import torch.nn as nn from torch import Tensor from torch.nn import functional as F -from torchvision.models.googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, model_urls -from ..._internally_replaced_utils import load_state_dict_from_url +from ...transforms import ImageClassificationEval, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _IMAGENET_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param +from ..googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, GoogLeNet_Weights from .utils import _fuse_modules, _replace_relu, quantize_model -__all__ = ["QuantizableGoogLeNet", "googlenet"] - -quant_model_urls = { - # fp32 GoogLeNet ported from TensorFlow, with weights quantized in PyTorch - "googlenet_fbgemm": "https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", -} +__all__ = [ + "QuantizableGoogLeNet", + "GoogLeNet_QuantizedWeights", + "googlenet", +] class QuantizableBasicConv2d(BasicConv2d): @@ -103,8 +106,41 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: m.fuse_model(is_qat) +class GoogLeNet_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + "task": "image_classification", + "architecture": "GoogLeNet", + "publication_year": 2014, + "num_params": 6624904, + "size": (224, 224), + "min_size": (15, 15), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "backend": "fbgemm", + "quantization": "ptq", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", + "unquantized": GoogLeNet_Weights.IMAGENET1K_V1, + "acc@1": 69.826, + "acc@5": 89.404, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda 
kwargs: GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else GoogLeNet_Weights.IMAGENET1K_V1, + ) +) def googlenet( - pretrained: bool = False, + *, + weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -117,49 +153,38 @@ def googlenet( GPU inference is not yet supported Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (GoogLeNet_QuantizedWeights or GoogLeNet_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model - aux_logits (bool): If True, adds two auxiliary branches that can improve training. - Default: *False* when pretrained is True otherwise *True* - transform_input (bool): If True, preprocesses the input according to the method with which it - was trained on ImageNet. Default: True if ``pretrained=True``, else False. """ - if pretrained: + weights = (GoogLeNet_QuantizedWeights if quantize else GoogLeNet_Weights).verify(weights) + + original_aux_logits = kwargs.get("aux_logits", False) + if weights is not None: if "transform_input" not in kwargs: - kwargs["transform_input"] = True - if "aux_logits" not in kwargs: - kwargs["aux_logits"] = False - if kwargs["aux_logits"]: - warnings.warn( - "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" - ) - original_aux_logits = kwargs["aux_logits"] - kwargs["aux_logits"] = True - kwargs["init_weights"] = False + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "init_weights", False) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "fbgemm") model = QuantizableGoogLeNet(**kwargs) _replace_relu(model) - if quantize: - # TODO use pretrained as a string to specify the backend - backend = "fbgemm" quantize_model(model, backend) - else: - assert pretrained in [True, False] - - if pretrained: - if quantize: - model_url = quant_model_urls["googlenet_" + backend] - else: - model_url = model_urls["googlenet"] - - state_dict = load_state_dict_from_url(model_url, progress=progress) - - model.load_state_dict(state_dict) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) if not original_aux_logits: model.aux_logits = False model.aux1 = None # type: ignore[assignment] model.aux2 = None # type: ignore[assignment] + else: + warnings.warn( + "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" + ) + return model diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py index 27d021428b9..697d99d4027 100644 --- a/torchvision/models/quantization/inception.py +++ b/torchvision/models/quantization/inception.py @@ -1,29 +1,28 @@ import warnings -from typing import Any, List, Optional +from functools import partial +from typing import Any, List, Optional, Union import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from torchvision.models import inception as inception_module -from torchvision.models.inception import InceptionOutputs +from torchvision.models.inception import InceptionOutputs, 
Inception_V3_Weights -from ..._internally_replaced_utils import load_state_dict_from_url +from ...transforms import ImageClassificationEval, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _IMAGENET_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param from .utils import _fuse_modules, _replace_relu, quantize_model __all__ = [ "QuantizableInception3", + "Inception_V3_QuantizedWeights", "inception_v3", ] -quant_model_urls = { - # fp32 weights ported from TensorFlow, quantized in PyTorch - "inception_v3_google_fbgemm": "https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth" -} - - class QuantizableBasicConv2d(inception_module.BasicConv2d): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -173,8 +172,41 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: m.fuse_model(is_qat) +class Inception_V3_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth", + transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342), + meta={ + "task": "image_classification", + "architecture": "InceptionV3", + "publication_year": 2015, + "num_params": 27161264, + "size": (299, 299), + "min_size": (75, 75), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "backend": "fbgemm", + "quantization": "ptq", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", + "unquantized": Inception_V3_Weights.IMAGENET1K_V1, + "acc@1": 77.176, + "acc@5": 93.354, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: Inception_V3_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else Inception_V3_Weights.IMAGENET1K_V1, + ) +) def inception_v3( - pretrained: bool = False, + *, + weights: Optional[Union[Inception_V3_QuantizedWeights, Inception_V3_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -191,48 +223,35 @@ def inception_v3( GPU inference is not yet supported Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (Inception_V3_QuantizedWeights or Inception_V3_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model - aux_logits (bool): If True, add an auxiliary branch that can improve training. - Default: *True* - transform_input (bool): If True, preprocesses the input according to the method with which it - was trained on ImageNet. Default: True if ``pretrained=True``, else False. 
""" - if pretrained: + weights = (Inception_V3_QuantizedWeights if quantize else Inception_V3_Weights).verify(weights) + + original_aux_logits = kwargs.get("aux_logits", False) + if weights is not None: if "transform_input" not in kwargs: - kwargs["transform_input"] = True - if "aux_logits" in kwargs: - original_aux_logits = kwargs["aux_logits"] - kwargs["aux_logits"] = True - else: - original_aux_logits = False + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "fbgemm") model = QuantizableInception3(**kwargs) _replace_relu(model) - if quantize: - # TODO use pretrained as a string to specify the backend - backend = "fbgemm" quantize_model(model, backend) - else: - assert pretrained in [True, False] - - if pretrained: - if quantize: - if not original_aux_logits: - model.aux_logits = False - model.AuxLogits = None - model_url = quant_model_urls["inception_v3_google_" + backend] - else: - model_url = inception_module.model_urls["inception_v3_google"] - - state_dict = load_state_dict_from_url(model_url, progress=progress) - model.load_state_dict(state_dict) + if weights is not None: + if quantize and not original_aux_logits: + model.aux_logits = False + model.AuxLogits = None + model.load_state_dict(weights.get_state_dict(progress=progress)) + if not quantize and not original_aux_logits: + model.aux_logits = False + model.AuxLogits = None - if not quantize: - if not original_aux_logits: - model.aux_logits = False - model.AuxLogits = None return model diff --git a/torchvision/models/quantization/mobilenet.py b/torchvision/models/quantization/mobilenet.py index 8f2c42db640..0a270d14d3a 100644 --- a/torchvision/models/quantization/mobilenet.py +++ b/torchvision/models/quantization/mobilenet.py @@ -1,4 +1,6 @@ -from .mobilenetv2 import QuantizableMobileNetV2, mobilenet_v2, __all__ as mv2_all -from .mobilenetv3 import QuantizableMobileNetV3, mobilenet_v3_large, __all__ as mv3_all +from .mobilenetv2 import * # noqa: F401, F403 +from .mobilenetv3 import * # noqa: F401, F403 +from .mobilenetv2 import __all__ as mv2_all +from .mobilenetv3 import __all__ as mv3_all __all__ = mv2_all + mv3_all diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index 8cd9f16d13e..40f5cb544fd 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -1,20 +1,24 @@ -from typing import Any, Optional +from functools import partial +from typing import Any, Optional, Union from torch import Tensor from torch import nn from torch.ao.quantization import QuantStub, DeQuantStub -from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls +from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, MobileNet_V2_Weights -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops.misc import Conv2dNormActivation +from ...transforms import ImageClassificationEval, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _IMAGENET_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param from .utils import _fuse_modules, _replace_relu, quantize_model -__all__ = ["QuantizableMobileNetV2", "mobilenet_v2"] - -quant_model_urls = { - "mobilenet_v2_qnnpack": 
"https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth" -} +__all__ = [ + "QuantizableMobileNetV2", + "MobileNet_V2_QuantizedWeights", + "mobilenet_v2", +] class QuantizableInvertedResidual(InvertedResidual): @@ -60,8 +64,41 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: m.fuse_model(is_qat) +class MobileNet_V2_QuantizedWeights(WeightsEnum): + IMAGENET1K_QNNPACK_V1 = Weights( + url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + "task": "image_classification", + "architecture": "MobileNetV2", + "publication_year": 2018, + "num_params": 3504872, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "backend": "qnnpack", + "quantization": "qat", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2", + "unquantized": MobileNet_V2_Weights.IMAGENET1K_V1, + "acc@1": 71.658, + "acc@5": 90.150, + }, + ) + DEFAULT = IMAGENET1K_QNNPACK_V1 + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1 + if kwargs.get("quantize", False) + else MobileNet_V2_Weights.IMAGENET1K_V1, + ) +) def mobilenet_v2( - pretrained: bool = False, + *, + weights: Optional[Union[MobileNet_V2_QuantizedWeights, MobileNet_V2_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -76,27 +113,25 @@ def mobilenet_v2( GPU inference is not yet supported Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet. - progress (bool): If True, displays a progress bar of the download to stderr - quantize(bool): If True, returns a quantized model, else returns a float model + weights (GoogLeNet_QuantizedWeights or GoogLeNet_Weights, optional): The pretrained + weights for the model + progress (bool): If True, displays a progress bar of the download to stderr + quantize(bool): If True, returns a quantized model, else returns a float model """ + weights = (MobileNet_V2_QuantizedWeights if quantize else MobileNet_V2_Weights).verify(weights) + + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "qnnpack") + model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs) _replace_relu(model) - if quantize: - # TODO use pretrained as a string to specify the backend - backend = "qnnpack" quantize_model(model, backend) - else: - assert pretrained in [True, False] - - if pretrained: - if quantize: - model_url = quant_model_urls["mobilenet_v2_" + backend] - else: - model_url = model_urls["mobilenet_v2"] - state_dict = load_state_dict_from_url(model_url, progress=progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) - model.load_state_dict(state_dict) return model diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 4d7e2f7baad..4b79b7f26ae 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -1,20 +1,30 @@ -from typing import Any, List, Optional +from functools import partial +from typing import Any, List, Optional, Union import torch from torch import nn, Tensor from 
torch.ao.quantization import QuantStub, DeQuantStub -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops.misc import Conv2dNormActivation, SqueezeExcitation -from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, model_urls, _mobilenet_v3_conf +from ...transforms import ImageClassificationEval, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _IMAGENET_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param +from ..mobilenetv3 import ( + InvertedResidual, + InvertedResidualConfig, + MobileNetV3, + _mobilenet_v3_conf, + MobileNet_V3_Large_Weights, +) from .utils import _fuse_modules, _replace_relu -__all__ = ["QuantizableMobileNetV3", "mobilenet_v3_large"] - -quant_model_urls = { - "mobilenet_v3_large_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", -} +__all__ = [ + "QuantizableMobileNetV3", + "MobileNet_V3_Large_QuantizedWeights", + "mobilenet_v3_large", +] class QuantizableSqueezeExcitation(SqueezeExcitation): @@ -112,47 +122,73 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: m.fuse_model(is_qat) -def _load_weights(arch: str, model: QuantizableMobileNetV3, model_url: Optional[str], progress: bool) -> None: - if model_url is None: - raise ValueError(f"No checkpoint is available for {arch}") - state_dict = load_state_dict_from_url(model_url, progress=progress) - model.load_state_dict(state_dict) - - def _mobilenet_v3_model( - arch: str, inverted_residual_setting: List[InvertedResidualConfig], last_channel: int, - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, quantize: bool, **kwargs: Any, ) -> QuantizableMobileNetV3: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "qnnpack") model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs) _replace_relu(model) if quantize: - backend = "qnnpack" - model.fuse_model(is_qat=True) model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend) torch.ao.quantization.prepare_qat(model, inplace=True) - if pretrained: - _load_weights(arch, model, quant_model_urls.get(arch + "_" + backend, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + if quantize: torch.ao.quantization.convert(model, inplace=True) model.eval() - else: - if pretrained: - _load_weights(arch, model, model_urls.get(arch, None), progress) return model +class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): + IMAGENET1K_QNNPACK_V1 = Weights( + url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + "task": "image_classification", + "architecture": "MobileNetV3", + "publication_year": 2019, + "num_params": 5483032, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "backend": "qnnpack", + "quantization": "qat", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3", + "unquantized": MobileNet_V3_Large_Weights.IMAGENET1K_V1, + "acc@1": 73.004, + "acc@5": 90.858, + }, + ) + DEFAULT = IMAGENET1K_QNNPACK_V1 + + +@handle_legacy_interface( + 
weights=( + "pretrained", + lambda kwargs: MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1 + if kwargs.get("quantize", False) + else MobileNet_V3_Large_Weights.IMAGENET1K_V1, + ) +) def mobilenet_v3_large( - pretrained: bool = False, + *, + weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -166,10 +202,12 @@ def mobilenet_v3_large( GPU inference is not yet supported Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet. - progress (bool): If True, displays a progress bar of the download to stderr - quantize (bool): If True, returns a quantized model, else returns a float model + weights (MobileNet_V3_Large_QuantizedWeights or MobileNet_V3_Large_Weights, optional): The pretrained + weights for the model + progress (bool): If True, displays a progress bar of the download to stderr + quantize (bool): If True, returns a quantized model, else returns a float model """ - arch = "mobilenet_v3_large" - inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs) - return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, quantize, **kwargs) + weights = (MobileNet_V3_Large_QuantizedWeights if quantize else MobileNet_V3_Large_Weights).verify(weights) + + inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) + return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs) diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index f55aa0e103c..666b1b23163 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -1,21 +1,34 @@ +from functools import partial from typing import Any, Type, Union, List, Optional import torch import torch.nn as nn from torch import Tensor -from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls - -from ..._internally_replaced_utils import load_state_dict_from_url +from torchvision.models.resnet import ( + Bottleneck, + BasicBlock, + ResNet, + ResNet18_Weights, + ResNet50_Weights, + ResNeXt101_32X8D_Weights, +) + +from ...transforms import ImageClassificationEval, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _IMAGENET_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param from .utils import _fuse_modules, _replace_relu, quantize_model -__all__ = ["QuantizableResNet", "resnet18", "resnet50", "resnext101_32x8d"] - -quant_model_urls = { - "resnet18_fbgemm": "https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", - "resnet50_fbgemm": "https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", - "resnext101_32x8d_fbgemm": "https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", -} +__all__ = [ + "QuantizableResNet", + "ResNet18_QuantizedWeights", + "ResNet50_QuantizedWeights", + "ResNeXt101_32X8D_QuantizedWeights", + "resnet18", + "resnet50", + "resnext101_32x8d", +] class QuantizableBasicBlock(BasicBlock): @@ -109,38 +122,130 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: def _resnet( - arch: str, block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]], layers: List[int], - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, quantize: bool, **kwargs: Any, ) -> QuantizableResNet: + if weights is not None: + 
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "fbgemm") model = QuantizableResNet(block, layers, **kwargs) _replace_relu(model) if quantize: - # TODO use pretrained as a string to specify the backend - backend = "fbgemm" quantize_model(model, backend) - else: - assert pretrained in [True, False] - if pretrained: - if quantize: - model_url = quant_model_urls[arch + "_" + backend] - else: - model_url = model_urls[arch] + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) - state_dict = load_state_dict_from_url(model_url, progress=progress) - - model.load_state_dict(state_dict) return model +_COMMON_META = { + "task": "image_classification", + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "backend": "fbgemm", + "quantization": "ptq", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", +} + + +class ResNet18_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 11689512, + "unquantized": ResNet18_Weights.IMAGENET1K_V1, + "acc@1": 69.494, + "acc@5": 88.882, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +class ResNet50_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 25557032, + "unquantized": ResNet50_Weights.IMAGENET1K_V1, + "acc@1": 75.920, + "acc@5": 92.814, + }, + ) + IMAGENET1K_FBGEMM_V2 = Weights( + url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 25557032, + "unquantized": ResNet50_Weights.IMAGENET1K_V2, + "acc@1": 80.282, + "acc@5": 94.976, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V2 + + +class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNeXt", + "publication_year": 2016, + "num_params": 88791336, + "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V1, + "acc@1": 78.986, + "acc@5": 94.480, + }, + ) + IMAGENET1K_FBGEMM_V2 = Weights( + url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNeXt", + "publication_year": 2016, + "num_params": 88791336, + "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V2, + "acc@1": 82.574, + "acc@5": 96.132, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V2 + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if 
kwargs.get("quantize", False) + else ResNet18_Weights.IMAGENET1K_V1, + ) +) def resnet18( - pretrained: bool = False, + *, + weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -149,33 +254,56 @@ def resnet18( `"Deep Residual Learning for Image Recognition" `_ Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNet18_QuantizedWeights or ResNet18_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - return _resnet("resnet18", QuantizableBasicBlock, [2, 2, 2, 2], pretrained, progress, quantize, **kwargs) + weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights) + + return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs) +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ResNet50_Weights.IMAGENET1K_V1, + ) +) def resnet50( - pretrained: bool = False, + *, + weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, ) -> QuantizableResNet: - r"""ResNet-50 model from `"Deep Residual Learning for Image Recognition" `_ Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNet50_QuantizedWeights or ResNet50_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - return _resnet("resnet50", QuantizableBottleneck, [3, 4, 6, 3], pretrained, progress, quantize, **kwargs) + weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights) + return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs) + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ResNeXt101_32X8D_Weights.IMAGENET1K_V1, + ) +) def resnext101_32x8d( - pretrained: bool = False, + *, + weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -184,10 +312,13 @@ def resnext101_32x8d( `"Aggregated Residual Transformation for Deep Neural Networks" `_ Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNeXt101_32X8D_QuantizedWeights or ResNeXt101_32X8D_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - kwargs["groups"] = 32 - kwargs["width_per_group"] = 8 - return _resnet("resnext101_32x8d", QuantizableBottleneck, [3, 4, 23, 3], pretrained, progress, quantize, **kwargs) + weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights) + + _ovewrite_named_param(kwargs, "groups", 32) + _ovewrite_named_param(kwargs, "width_per_group", 8) + return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs) diff --git a/torchvision/models/quantization/shufflenetv2.py 
b/torchvision/models/quantization/shufflenetv2.py index 9d25315ffa0..c5bfe698636 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -1,24 +1,27 @@ -from typing import Any, Optional +from functools import partial +from typing import Any, List, Optional, Union import torch import torch.nn as nn from torch import Tensor from torchvision.models import shufflenetv2 -from ..._internally_replaced_utils import load_state_dict_from_url +from ...transforms import ImageClassificationEval, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _IMAGENET_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param +from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights from .utils import _fuse_modules, _replace_relu, quantize_model + __all__ = [ "QuantizableShuffleNetV2", + "ShuffleNet_V2_X0_5_QuantizedWeights", + "ShuffleNet_V2_X1_0_QuantizedWeights", "shufflenet_v2_x0_5", "shufflenet_v2_x1_0", ] -quant_model_urls = { - "shufflenetv2_x0.5_fbgemm": "https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", - "shufflenetv2_x1.0_fbgemm": "https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", -} - class QuantizableInvertedResidual(shufflenetv2.InvertedResidual): def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -73,39 +76,86 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: def _shufflenetv2( - arch: str, - pretrained: bool, + stages_repeats: List[int], + stages_out_channels: List[int], + *, + weights: Optional[WeightsEnum], progress: bool, quantize: bool, - *args: Any, **kwargs: Any, ) -> QuantizableShuffleNetV2: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "fbgemm") - model = QuantizableShuffleNetV2(*args, **kwargs) + model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs) _replace_relu(model) - if quantize: - # TODO use pretrained as a string to specify the backend - backend = "fbgemm" quantize_model(model, backend) - else: - assert pretrained in [True, False] - if pretrained: - model_url: Optional[str] = None - if quantize: - model_url = quant_model_urls[arch + "_" + backend] - else: - model_url = shufflenetv2.model_urls[arch] - - state_dict = load_state_dict_from_url(model_url, progress=progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) - model.load_state_dict(state_dict) return model +_COMMON_META = { + "task": "image_classification", + "architecture": "ShuffleNetV2", + "publication_year": 2018, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "backend": "fbgemm", + "quantization": "ptq", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", +} + + +class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 1366792, + "unquantized": ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1, + "acc@1": 57.972, + "acc@5": 79.780, + }, + ) + DEFAULT = 
IMAGENET1K_FBGEMM_V1 + + +class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 2278604, + "unquantized": ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1, + "acc@1": 68.360, + "acc@5": 87.582, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1, + ) +) def shufflenet_v2_x0_5( - pretrained: bool = False, + *, + weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -116,17 +166,28 @@ def shufflenet_v2_x0_5( `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ShuffleNet_V2_X0_5_QuantizedWeights or ShuffleNet_V2_X0_5_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ + weights = (ShuffleNet_V2_X0_5_QuantizedWeights if quantize else ShuffleNet_V2_X0_5_Weights).verify(weights) return _shufflenetv2( - "shufflenetv2_x0.5", pretrained, progress, quantize, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs + [4, 8, 4], [24, 48, 96, 192, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs ) +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1, + ) +) def shufflenet_v2_x1_0( - pretrained: bool = False, + *, + weights: Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -137,10 +198,12 @@ def shufflenet_v2_x1_0( `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ShuffleNet_V2_X1_0_QuantizedWeights or ShuffleNet_V2_X1_0_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ + weights = (ShuffleNet_V2_X1_0_QuantizedWeights if quantize else ShuffleNet_V2_X1_0_Weights).verify(weights) return _shufflenetv2( - "shufflenetv2_x1.0", pretrained, progress, quantize, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs + [4, 8, 4], [24, 116, 232, 464, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs ) diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index 74abd20b237..1015c21b858 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -1,8 +1,3 @@ -# Modified from -# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/anynet.py -# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py - - import math from collections import OrderedDict from functools import partial @@ -11,14 +6,31 @@ import torch from torch import nn, Tensor -from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import Conv2dNormActivation, SqueezeExcitation +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once -from ._utils import _make_divisible +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible __all__ = [ "RegNet", + "RegNet_Y_400MF_Weights", + "RegNet_Y_800MF_Weights", + "RegNet_Y_1_6GF_Weights", + "RegNet_Y_3_2GF_Weights", + "RegNet_Y_8GF_Weights", + "RegNet_Y_16GF_Weights", + "RegNet_Y_32GF_Weights", + "RegNet_Y_128GF_Weights", + "RegNet_X_400MF_Weights", + "RegNet_X_800MF_Weights", + "RegNet_X_1_6GF_Weights", + "RegNet_X_3_2GF_Weights", + "RegNet_X_8GF_Weights", + "RegNet_X_16GF_Weights", + "RegNet_X_32GF_Weights", "regnet_y_400mf", "regnet_y_800mf", "regnet_y_1_6gf", @@ -37,24 +49,6 @@ ] -model_urls = { - "regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", - "regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", - "regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", - "regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", - "regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", - "regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", - "regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", - "regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", - "regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", - "regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", - "regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", - "regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", - "regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", - "regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", -} - - class SimpleStemIN(Conv2dNormActivation): """Simple stem for ImageNet: 3x3, BN, ReLU.""" @@ -390,219 
+384,652 @@ def forward(self, x: Tensor) -> Tensor: return x -def _regnet(arch: str, block_params: BlockParams, pretrained: bool, progress: bool, **kwargs: Any) -> RegNet: +def _regnet( + block_params: BlockParams, + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> RegNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1)) model = RegNet(block_params, norm_layer=norm_layer, **kwargs) - if pretrained: - if arch not in model_urls: - raise ValueError(f"No checkpoint is available for model type {arch}") - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model -def regnet_y_400mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +_COMMON_META = { + "task": "image_classification", + "architecture": "RegNet", + "publication_year": 2020, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class RegNet_Y_400MF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 4344144, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", + "acc@1": 74.046, + "acc@5": 91.716, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 4344144, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 75.804, + "acc@5": 92.742, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_800MF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 6432512, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", + "acc@1": 76.420, + "acc@5": 93.136, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 6432512, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 78.828, + "acc@5": 94.502, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_1_6GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 11202430, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", + "acc@1": 77.950, + "acc@5": 93.966, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 11202430, + "recipe": 
"https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 80.876, + "acc@5": 95.444, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_3_2GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 19436338, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", + "acc@1": 78.948, + "acc@5": 94.576, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 19436338, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 81.982, + "acc@5": 95.972, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_8GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 39381472, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", + "acc@1": 80.032, + "acc@5": 95.048, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 39381472, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 82.828, + "acc@5": 96.330, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_16GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 83590140, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", + "acc@1": 80.424, + "acc@5": 95.240, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 83590140, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 82.886, + "acc@5": 96.328, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_32GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 145046770, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", + "acc@1": 80.878, + "acc@5": 95.340, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 145046770, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 83.368, + "acc@5": 96.498, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_128GF_Weights(WeightsEnum): + # weights are not available yet. 
+ pass + + +class RegNet_X_400MF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 5495976, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", + "acc@1": 72.834, + "acc@5": 90.950, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 5495976, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "acc@1": 74.864, + "acc@5": 92.322, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_800MF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 7259656, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", + "acc@1": 75.212, + "acc@5": 92.348, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 7259656, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "acc@1": 77.522, + "acc@5": 93.826, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_1_6GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 9190136, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", + "acc@1": 77.040, + "acc@5": 93.440, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 9190136, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "acc@1": 79.668, + "acc@5": 94.922, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_3_2GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 15296552, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", + "acc@1": 78.364, + "acc@5": 93.992, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 15296552, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 81.196, + "acc@5": 95.430, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_8GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + 
**_COMMON_META, + "num_params": 39572648, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", + "acc@1": 79.344, + "acc@5": 94.686, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 39572648, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 81.682, + "acc@5": 95.678, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_16GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 54278536, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", + "acc@1": 80.058, + "acc@5": 94.944, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 54278536, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 82.716, + "acc@5": 96.196, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_32GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 107811560, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", + "acc@1": 80.622, + "acc@5": 95.248, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 107811560, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 83.014, + "acc@5": 96.288, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +@handle_legacy_interface(weights=("pretrained", RegNet_Y_400MF_Weights.IMAGENET1K_V1)) +def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_400MF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_400MF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_400MF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs) - return _regnet("regnet_y_400mf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_800mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_Y_800MF_Weights.IMAGENET1K_V1)) +def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_800MF architecture from `"Designing Network Design Spaces" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_800MF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_800MF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs) - return _regnet("regnet_y_800mf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_1_6gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_Y_1_6GF_Weights.IMAGENET1K_V1)) +def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_1.6GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_1_6GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_1_6GF_Weights.verify(weights) + params = BlockParams.from_init_params( depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs ) - return _regnet("regnet_y_1_6gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_3_2gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_Y_3_2GF_Weights.IMAGENET1K_V1)) +def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_3.2GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_3_2GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_3_2GF_Weights.verify(weights) + params = BlockParams.from_init_params( depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs ) - return _regnet("regnet_y_3_2gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_8gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_Y_8GF_Weights.IMAGENET1K_V1)) +def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_8GF architecture from `"Designing Network Design Spaces" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_8GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_8GF_Weights.verify(weights) + params = BlockParams.from_init_params( depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs ) - return _regnet("regnet_y_8gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_16gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_Y_16GF_Weights.IMAGENET1K_V1)) +def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_16GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_16GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_16GF_Weights.verify(weights) + params = BlockParams.from_init_params( depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs ) - return _regnet("regnet_y_16gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_Y_32GF_Weights.IMAGENET1K_V1)) +def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_32GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_32GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_32GF_Weights.verify(weights) + params = BlockParams.from_init_params( depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs ) - return _regnet("regnet_y_32gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_128gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", None)) +def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_128GF architecture from `"Designing Network Design Spaces" `_. NOTE: Pretrained weights are not available for this model. 
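A minimal sketch of what the @handle_legacy_interface decorators above are intended to provide for callers, assuming the decorator maps the deprecated pretrained flag to the default weights named in it and emits a deprecation warning:

    from torchvision.models import RegNet_Y_8GF_Weights, regnet_y_8gf

    # New-style call: pick the checkpoint explicitly.
    m_new = regnet_y_8gf(weights=RegNet_Y_8GF_Weights.IMAGENET1K_V1)

    # Legacy call: still accepted, assumed to resolve to IMAGENET1K_V1 via the decorator.
    m_old = regnet_y_8gf(pretrained=True)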
+ + Args: + weights (RegNet_Y_128GF_Weights, optional): The pretrained weights for the model + progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_128GF_Weights.verify(weights) + params = BlockParams.from_init_params( depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs ) - return _regnet("regnet_y_128gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_400mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.IMAGENET1K_V1)) +def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_400MF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_400MF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_X_400MF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs) - return _regnet("regnet_x_400mf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_800mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_800MF_Weights.IMAGENET1K_V1)) +def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_800MF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_800MF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_X_800MF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs) - return _regnet("regnet_x_800mf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_1_6gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_1_6GF_Weights.IMAGENET1K_V1)) +def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_1.6GF architecture from `"Designing Network Design Spaces" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_1_6GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_X_1_6GF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs) - return _regnet("regnet_x_1_6gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_3_2gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_3_2GF_Weights.IMAGENET1K_V1)) +def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_3.2GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_3_2GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_X_3_2GF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs) - return _regnet("regnet_x_3_2gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_8gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_8GF_Weights.IMAGENET1K_V1)) +def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_8GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_8GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_X_8GF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs) - return _regnet("regnet_x_8gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_16gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_16GF_Weights.IMAGENET1K_V1)) +def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_16GF architecture from `"Designing Network Design Spaces" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_16GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_X_16GF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs) - return _regnet("regnet_x_16gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_32GF_Weights.IMAGENET1K_V1)) +def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_32GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_32GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs) - return _regnet("regnet_x_32gf", params, pretrained, progress, **kwargs) + weights = RegNet_X_32GF_Weights.verify(weights) - -# TODO(kazhang): Add RegNetZ_500MF and RegNetZ_4GF + params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs) + return _regnet(params, weights, progress, **kwargs) diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index b0bb8d13ade..159749df006 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -1,15 +1,28 @@ +from functools import partial from typing import Type, Any, Callable, Union, List, Optional import torch import torch.nn as nn from torch import Tensor -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = [ "ResNet", + "ResNet18_Weights", + "ResNet34_Weights", + "ResNet50_Weights", + "ResNet101_Weights", + "ResNet152_Weights", + "ResNeXt50_32X4D_Weights", + "ResNeXt101_32X8D_Weights", + "Wide_ResNet50_2_Weights", + "Wide_ResNet101_2_Weights", "resnet18", "resnet34", "resnet50", @@ -22,19 +35,6 @@ ] -model_urls = { - "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth", - "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth", - "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth", - "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth", - "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth", - "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", - "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", - "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", - "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", -} - - def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: """3x3 convolution with padding""" return nn.Conv2d( @@ -284,102 +284,386 @@ def forward(self, x: 
Tensor) -> Tensor: def _resnet( - arch: str, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> ResNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = ResNet(block, layers, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model -def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: +_COMMON_META = { + "task": "image_classification", + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class ResNet18_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet18-f37072fd.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 11689512, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "acc@1": 69.758, + "acc@5": 89.078, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ResNet34_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet34-b627a593.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 21797672, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "acc@1": 73.314, + "acc@5": 91.420, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ResNet50_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet50-0676ba61.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 25557032, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "acc@1": 76.130, + "acc@5": 92.862, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 25557032, + "recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621", + "acc@1": 80.858, + "acc@5": 95.434, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNet101_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet101-63fe2227.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 44549160, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "acc@1": 77.374, + "acc@5": 93.546, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnet101-cd907fc2.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 44549160, + "recipe": 
"https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 81.886, + "acc@5": 95.780, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNet152_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet152-394f9c45.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 60192808, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "acc@1": 78.312, + "acc@5": 94.046, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnet152-f82ba261.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 60192808, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 82.284, + "acc@5": 96.002, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNeXt50_32X4D_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNeXt", + "publication_year": 2016, + "num_params": 25028904, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext", + "acc@1": 77.618, + "acc@5": 93.698, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNeXt", + "publication_year": 2016, + "num_params": 25028904, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 81.198, + "acc@5": 95.340, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNeXt101_32X8D_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNeXt", + "publication_year": 2016, + "num_params": 88791336, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext", + "acc@1": 79.312, + "acc@5": 94.526, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNeXt", + "publication_year": 2016, + "num_params": 88791336, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "acc@1": 82.834, + "acc@5": 96.228, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class Wide_ResNet50_2_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "WideResNet", + "publication_year": 2016, + "num_params": 68883240, + "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439", + "acc@1": 78.468, + "acc@5": 94.086, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", + 
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "WideResNet", + "publication_year": 2016, + "num_params": 68883240, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "acc@1": 81.602, + "acc@5": 95.758, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class Wide_ResNet101_2_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "WideResNet", + "publication_year": 2016, + "num_params": 126886696, + "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439", + "acc@1": 78.848, + "acc@5": 94.284, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "WideResNet", + "publication_year": 2016, + "num_params": 126886696, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 82.510, + "acc@5": 96.020, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1)) +def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-18 model from `"Deep Residual Learning for Image Recognition" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNet18_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) + weights = ResNet18_Weights.verify(weights) + + return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) -def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: +@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1)) +def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-34 model from `"Deep Residual Learning for Image Recognition" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNet34_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) + weights = ResNet34_Weights.verify(weights) + return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) -def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + +@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1)) +def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-50 model from `"Deep Residual Learning for Image Recognition" `_. 
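A short, hedged end-to-end sketch of the new ResNet weights API above, assuming the enum members expose the transforms and meta of their underlying Weights entry:

    import torch
    from torchvision.models import ResNet50_Weights, resnet50

    weights = ResNet50_Weights.IMAGENET1K_V2      # ResNet50_Weights.DEFAULT points here
    model = resnet50(weights=weights).eval()

    preprocess = weights.transforms()             # ImageClassificationEval(crop_size=224, resize_size=232)
    batch = preprocess(torch.rand(3, 300, 300)).unsqueeze(0)
    class_id = model(batch).argmax(dim=1).item()
    print(weights.meta["categories"][class_id])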
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNet50_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) + weights = ResNet50_Weights.verify(weights) + + return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) -def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: +@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1)) +def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-101 model from `"Deep Residual Learning for Image Recognition" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNet101_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) + weights = ResNet101_Weights.verify(weights) + return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) -def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + +@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1)) +def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-152 model from `"Deep Residual Learning for Image Recognition" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNet152_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs) + weights = ResNet152_Weights.verify(weights) + + return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs) -def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: +@handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.IMAGENET1K_V1)) +def resnext50_32x4d( + *, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: r"""ResNeXt-50 32x4d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNeXt50_32X4D_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - kwargs["groups"] = 32 - kwargs["width_per_group"] = 4 - return _resnet("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) + weights = ResNeXt50_32X4D_Weights.verify(weights) + _ovewrite_named_param(kwargs, "groups", 32) + _ovewrite_named_param(kwargs, "width_per_group", 4) + return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) -def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + +@handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.IMAGENET1K_V1)) +def resnext101_32x8d( + *, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: r"""ResNeXt-101 32x8d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNeXt101_32X8D_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - kwargs["groups"] = 32 - kwargs["width_per_group"] = 8 - return _resnet("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) + weights = ResNeXt101_32X8D_Weights.verify(weights) + _ovewrite_named_param(kwargs, "groups", 32) + _ovewrite_named_param(kwargs, "width_per_group", 8) + return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) -def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + +@handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1)) +def wide_resnet50_2( + *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: r"""Wide ResNet-50-2 model from `"Wide Residual Networks" `_. @@ -389,14 +673,19 @@ def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: A channels, and in Wide ResNet-50-2 has 2048-1024-2048. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (Wide_ResNet50_2_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - kwargs["width_per_group"] = 64 * 2 - return _resnet("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) + weights = Wide_ResNet50_2_Weights.verify(weights) + + _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) + return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) -def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: +@handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.IMAGENET1K_V1)) +def wide_resnet101_2( + *, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: r"""Wide ResNet-101-2 model from `"Wide Residual Networks" `_. @@ -406,8 +695,10 @@ def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
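A hedged sketch of the num_classes handling added to _resnet above: when weights are supplied, _ovewrite_named_param pins num_classes to the checkpoint's category count, so only the weights=None path takes a custom head size (the conflicting value failing is an assumption based on how the helper is used here):

    from torchvision.models import Wide_ResNet50_2_Weights, wide_resnet50_2

    custom = wide_resnet50_2(weights=None, num_classes=10)                        # random init, 10-way head
    pretrained = wide_resnet50_2(weights=Wide_ResNet50_2_Weights.IMAGENET1K_V2)   # 1000 classes from meta
    # wide_resnet50_2(weights=Wide_ResNet50_2_Weights.IMAGENET1K_V2, num_classes=10)
    #   -> assumed to be rejected, since 10 conflicts with the checkpoint's categories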
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (Wide_ResNet101_2_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - kwargs["width_per_group"] = 64 * 2 - return _resnet("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) + weights = Wide_ResNet101_2_Weights.verify(weights) + + _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) + return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index f3758c54aaf..e196dedcc3e 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -1,21 +1,28 @@ -from typing import Callable, Any, List +from functools import partial +from typing import Callable, Any, List, Optional import torch import torch.nn as nn from torch import Tensor -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["ShuffleNetV2", "shufflenet_v2_x0_5", "shufflenet_v2_x1_0", "shufflenet_v2_x1_5", "shufflenet_v2_x2_0"] - -model_urls = { - "shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", - "shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", - "shufflenetv2_x1.5": None, - "shufflenetv2_x2.0": None, -} +__all__ = [ + "ShuffleNetV2", + "ShuffleNet_V2_X0_5_Weights", + "ShuffleNet_V2_X1_0_Weights", + "ShuffleNet_V2_X1_5_Weights", + "ShuffleNet_V2_X2_0_Weights", + "shufflenet_v2_x0_5", + "shufflenet_v2_x1_0", + "shufflenet_v2_x1_5", + "shufflenet_v2_x2_0", +] def channel_shuffle(x: Tensor, groups: int) -> Tensor: @@ -156,67 +163,138 @@ def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) -def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwargs: Any) -> ShuffleNetV2: +def _shufflenetv2( + weights: Optional[WeightsEnum], + progress: bool, + *args: Any, + **kwargs: Any, +) -> ShuffleNetV2: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = ShuffleNetV2(*args, **kwargs) - if pretrained: - model_url = model_urls[arch] - if model_url is None: - raise ValueError(f"No checkpoint is available for model type {arch}") - else: - state_dict = load_state_dict_from_url(model_url, progress=progress) - model.load_state_dict(state_dict) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) return model -def shufflenet_v2_x0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: +_COMMON_META = { + "task": "image_classification", + "architecture": "ShuffleNetV2", + "publication_year": 2018, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/barrh/Shufflenet-v2-Pytorch/tree/v0.1.0", +} + + +class ShuffleNet_V2_X0_5_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 1366792, + "acc@1": 60.552, +
"acc@5": 88.316, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ShuffleNet_V2_X1_0_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 2278604, + "acc@1": 60.552, + "acc@5": 81.746, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ShuffleNet_V2_X1_5_Weights(WeightsEnum): + pass + + +class ShuffleNet_V2_X2_0_Weights(WeightsEnum): + pass + + +@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1)) +def shufflenet_v2_x0_5( + *, weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: """ Constructs a ShuffleNetV2 with 0.5x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ShuffleNet_V2_X0_5_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _shufflenetv2("shufflenetv2_x0.5", pretrained, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) + weights = ShuffleNet_V2_X0_5_Weights.verify(weights) + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) -def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: + +@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1)) +def shufflenet_v2_x1_0( + *, weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: """ Constructs a ShuffleNetV2 with 1.0x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ShuffleNet_V2_X1_0_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _shufflenetv2("shufflenetv2_x1.0", pretrained, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) + weights = ShuffleNet_V2_X1_0_Weights.verify(weights) + + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) -def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: +@handle_legacy_interface(weights=("pretrained", None)) +def shufflenet_v2_x1_5( + *, weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: """ Constructs a ShuffleNetV2 with 1.5x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ShuffleNet_V2_X1_5_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _shufflenetv2("shufflenetv2_x1.5", pretrained, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) + weights = ShuffleNet_V2_X1_5_Weights.verify(weights) + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) -def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: + +@handle_legacy_interface(weights=("pretrained", None)) +def shufflenet_v2_x2_0( + *, weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: """ Constructs a ShuffleNetV2 with 2.0x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ShuffleNet_V2_X2_0_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _shufflenetv2("shufflenetv2_x2.0", pretrained, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) + weights = ShuffleNet_V2_X2_0_Weights.verify(weights) + + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index 2c1a30f225d..d495b3148e5 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -1,18 +1,18 @@ -from typing import Any +from functools import partial +from typing import Any, Optional import torch import torch.nn as nn import torch.nn.init as init -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["SqueezeNet", "squeezenet1_0", "squeezenet1_1"] -model_urls = { - "squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", - "squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", -} +__all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"] class Fire(nn.Module): @@ -97,29 +97,85 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.flatten(x, 1) -def _squeezenet(version: str, pretrained: bool, progress: bool, **kwargs: Any) -> SqueezeNet: +def _squeezenet( + version: str, + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> SqueezeNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = SqueezeNet(version, **kwargs) - if pretrained: - arch = "squeezenet" + version - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model -def squeezenet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet: +_COMMON_META = { + "task": "image_classification", + "architecture": "SqueezeNet", + "publication_year": 2016, + "size": (224, 224), + "categories": _IMAGENET_CATEGORIES, + 
"interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717", +} + + +class SqueezeNet1_0_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "min_size": (21, 21), + "num_params": 1248424, + "acc@1": 58.092, + "acc@5": 80.420, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class SqueezeNet1_1_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "min_size": (17, 17), + "num_params": 1235496, + "acc@1": 58.178, + "acc@5": 80.624, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", SqueezeNet1_0_Weights.IMAGENET1K_V1)) +def squeezenet1_0( + *, weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any +) -> SqueezeNet: r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size" `_ paper. The required minimum input size of the model is 21x21. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (SqueezeNet1_0_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _squeezenet("1_0", pretrained, progress, **kwargs) + weights = SqueezeNet1_0_Weights.verify(weights) + return _squeezenet("1_0", weights, progress, **kwargs) -def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet: +@handle_legacy_interface(weights=("pretrained", SqueezeNet1_1_Weights.IMAGENET1K_V1)) +def squeezenet1_1( + *, weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any +) -> SqueezeNet: r"""SqueezeNet 1.1 model from the `official SqueezeNet repo `_. SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters @@ -127,7 +183,8 @@ def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any The required minimum input size of the model is 17x17. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (SqueezeNet1_1_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _squeezenet("1_1", pretrained, progress, **kwargs) + weights = SqueezeNet1_1_Weights.verify(weights) + return _squeezenet("1_1", weights, progress, **kwargs) diff --git a/torchvision/models/vgg.py b/torchvision/models/vgg.py index 07639017a31..5393827b293 100644 --- a/torchvision/models/vgg.py +++ b/torchvision/models/vgg.py @@ -1,37 +1,37 @@ -from typing import Union, List, Dict, Any, cast +from functools import partial +from typing import Union, List, Dict, Any, Optional, cast import torch import torch.nn as nn -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = [ "VGG", + "VGG11_Weights", + "VGG11_BN_Weights", + "VGG13_Weights", + "VGG13_BN_Weights", + "VGG16_Weights", + "VGG16_BN_Weights", + "VGG19_Weights", + "VGG19_BN_Weights", "vgg11", "vgg11_bn", "vgg13", "vgg13_bn", "vgg16", "vgg16_bn", - "vgg19_bn", "vgg19", + "vgg19_bn", ] -model_urls = { - "vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth", - "vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth", - "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth", - "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", - "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth", - "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", - "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", - "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", -} - - class VGG(nn.Module): def __init__( self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5 @@ -95,107 +95,274 @@ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequ } -def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG: - if pretrained: - kwargs["init_weights"] = False +def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) return model -def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: +_COMMON_META = { + "task": "image_classification", + "architecture": "VGG", + "publication_year": 2014, + "size": (224, 224), + "min_size": (32, 32), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg", +} + + +class VGG11_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg11-8a719046.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + 
"num_params": 132863336, + "acc@1": 69.020, + "acc@5": 88.628, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG11_BN_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 132868840, + "acc@1": 70.370, + "acc@5": 89.810, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG13_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg13-19584684.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 133047848, + "acc@1": 69.928, + "acc@5": 89.246, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG13_BN_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 133053736, + "acc@1": 71.586, + "acc@5": 90.374, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG16_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg16-397923af.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 138357544, + "acc@1": 71.592, + "acc@5": 90.382, + }, + ) + # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the + # same input standardization method as the paper. Only the `features` weights have proper values, those on the + # `classifier` module are filled with nans. + IMAGENET1K_FEATURES = Weights( + url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth", + transforms=partial( + ImageClassificationEval, + crop_size=224, + mean=(0.48235, 0.45882, 0.40784), + std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0), + ), + meta={ + **_COMMON_META, + "num_params": 138357544, + "categories": None, + "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd", + "acc@1": float("nan"), + "acc@5": float("nan"), + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG16_BN_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 138365992, + "acc@1": 73.360, + "acc@5": 91.516, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG19_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 143667240, + "acc@1": 72.376, + "acc@5": 90.876, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG19_BN_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 143678248, + "acc@1": 74.218, + "acc@5": 91.842, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1)) +def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 11-layer model (configuration "A") from `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG11_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg11", "A", False, pretrained, progress, **kwargs) + weights = VGG11_Weights.verify(weights) + return _vgg("A", False, weights, progress, **kwargs) -def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: + +@handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1)) +def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 11-layer model (configuration "A") with batch normalization `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG11_BN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg11_bn", "A", True, pretrained, progress, **kwargs) + weights = VGG11_BN_Weights.verify(weights) + + return _vgg("A", True, weights, progress, **kwargs) -def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: +@handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1)) +def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 13-layer model (configuration "B") `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG13_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg13", "B", False, pretrained, progress, **kwargs) + weights = VGG13_Weights.verify(weights) + return _vgg("B", False, weights, progress, **kwargs) -def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: + +@handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1)) +def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 13-layer model (configuration "B") with batch normalization `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG13_BN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg13_bn", "B", True, pretrained, progress, **kwargs) + weights = VGG13_BN_Weights.verify(weights) + + return _vgg("B", True, weights, progress, **kwargs) -def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: +@handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1)) +def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 16-layer model (configuration "D") `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG16_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg16", "D", False, pretrained, progress, **kwargs) + weights = VGG16_Weights.verify(weights) + return _vgg("D", False, weights, progress, **kwargs) -def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: + +@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1)) +def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 16-layer model (configuration "D") with batch normalization `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG16_BN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg16_bn", "D", True, pretrained, progress, **kwargs) + weights = VGG16_BN_Weights.verify(weights) + + return _vgg("D", True, weights, progress, **kwargs) -def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: +@handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1)) +def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 19-layer model (configuration "E") `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG19_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg19", "E", False, pretrained, progress, **kwargs) + weights = VGG19_Weights.verify(weights) + return _vgg("E", False, weights, progress, **kwargs) -def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: + +@handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1)) +def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 19-layer model (configuration 'E') with batch normalization `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG19_BN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg19_bn", "E", True, pretrained, progress, **kwargs) + weights = VGG19_BN_Weights.verify(weights) + + return _vgg("E", True, weights, progress, **kwargs) diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 29f756ccbe5..b11b7377ed1 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -6,25 +6,26 @@ import torch import torch.nn as nn -from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import Conv2dNormActivation +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param + __all__ = [ "VisionTransformer", + "ViT_B_16_Weights", + "ViT_B_32_Weights", + "ViT_L_16_Weights", + "ViT_L_32_Weights", "vit_b_16", "vit_b_32", "vit_l_16", "vit_l_32", ] -model_urls = { - "vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth", - "vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", - "vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", - "vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth", -} - class ConvStemConfig(NamedTuple): out_channels: int @@ -274,18 +275,20 @@ def forward(self, x: torch.Tensor): def _vision_transformer( - arch: str, patch_size: int, num_layers: int, num_heads: int, hidden_dim: int, mlp_dim: int, - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> VisionTransformer: image_size = kwargs.pop("image_size", 224) + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = VisionTransformer( image_size=image_size, patch_size=patch_size, @@ -296,98 +299,180 @@ def _vision_transformer( **kwargs, ) - if pretrained: - if arch not in model_urls: - raise ValueError(f"No checkpoint is available for model type '{arch}'!") - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + if weights: + model.load_state_dict(weights.get_state_dict(progress=progress)) return model -def vit_b_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: +_COMMON_META = { + "task": "image_classification", + "architecture": "ViT", + "publication_year": 2020, + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class ViT_B_16_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vit_b_16-c867db91.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 86567656, + "size": (224, 224), + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_16", + "acc@1": 81.072, + "acc@5": 95.318, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ViT_B_32_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 88224232, + "size": (224, 224), + 
"min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_32", + "acc@1": 75.912, + "acc@5": 92.466, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ViT_L_16_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=242), + meta={ + **_COMMON_META, + "num_params": 304326632, + "size": (224, 224), + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_16", + "acc@1": 79.662, + "acc@5": 94.638, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ViT_L_32_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vit_l_32-c7638314.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 306535400, + "size": (224, 224), + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_32", + "acc@1": 76.972, + "acc@5": 93.07, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1)) +def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_b_16 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ViT_B_16_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ViT_B_16_Weights.verify(weights) + return _vision_transformer( - arch="vit_b_16", patch_size=16, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, - pretrained=pretrained, + weights=weights, progress=progress, **kwargs, ) -def vit_b_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: +@handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.IMAGENET1K_V1)) +def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_b_32 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ViT_B_32_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ViT_B_32_Weights.verify(weights) + return _vision_transformer( - arch="vit_b_32", patch_size=32, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, - pretrained=pretrained, + weights=weights, progress=progress, **kwargs, ) -def vit_l_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: +@handle_legacy_interface(weights=("pretrained", ViT_L_16_Weights.IMAGENET1K_V1)) +def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_l_16 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ViT_L_16_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ViT_L_16_Weights.verify(weights) + return _vision_transformer( - arch="vit_l_16", patch_size=16, num_layers=24, num_heads=16, hidden_dim=1024, mlp_dim=4096, - pretrained=pretrained, + weights=weights, progress=progress, **kwargs, ) -def vit_l_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: +@handle_legacy_interface(weights=("pretrained", ViT_L_32_Weights.IMAGENET1K_V1)) +def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_l_32 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ViT_L_32_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ViT_L_32_Weights.verify(weights) + return _vision_transformer( - arch="vit_l_32", patch_size=32, num_layers=24, num_heads=16, hidden_dim=1024, mlp_dim=4096, - pretrained=pretrained, + weights=weights, progress=progress, **kwargs, ) diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py index 83e49908348..5988c160aad 100644 --- a/torchvision/prototype/models/__init__.py +++ b/torchvision/prototype/models/__init__.py @@ -1,20 +1,4 @@ -from .alexnet import * -from .convnext import * -from .densenet import * -from .efficientnet import * -from .googlenet import * -from .inception import * -from .mnasnet import * -from .mobilenet import * -from .regnet import * -from .resnet import * -from .shufflenetv2 import * -from .squeezenet import * -from .vgg import * -from .vision_transformer import * from . import detection from . import optical_flow -from . import quantization from . import segmentation from . import video -from ._api import get_weight diff --git a/torchvision/prototype/models/_utils.py b/torchvision/prototype/models/_utils.py deleted file mode 100644 index cc9f7dcfc36..00000000000 --- a/torchvision/prototype/models/_utils.py +++ /dev/null @@ -1,108 +0,0 @@ -import functools -import warnings -from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union - -from torch import nn -from torchvision.prototype.utils._internal import kwonly_to_pos_or_kw - -from ._api import WeightsEnum - -W = TypeVar("W", bound=WeightsEnum) -M = TypeVar("M", bound=nn.Module) -V = TypeVar("V") - - -def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): - """Decorates a model builder with the new interface to make it compatible with the old. - - In particular this handles two things: - - 1. Allows positional parameters again, but emits a deprecation warning in case they are used. See - :func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details. - 2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to - ``weights=Weights`` and emits a deprecation warning with instructions for the new interface. - - Args: - **weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter - name and default value for the legacy ``pretrained=True``. 
The default value can be a callable in which - case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be in - the dictionary is the deprecated parameter name passed as first element in the tuple. All other parameters - should be accessed with :meth:`~dict.get`. - """ - - def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]: - @kwonly_to_pos_or_kw - @functools.wraps(builder) - def inner_wrapper(*args: Any, **kwargs: Any) -> M: - for weights_param, (pretrained_param, default) in weights.items(): # type: ignore[union-attr] - # If neither the weights nor the pretrained parameter as passed, or the weights argument already use - # the new style arguments, there is nothing to do. Note that we cannot use `None` as sentinel for the - # weight argument, since it is a valid value. - sentinel = object() - weights_arg = kwargs.get(weights_param, sentinel) - if ( - (weights_param not in kwargs and pretrained_param not in kwargs) - or isinstance(weights_arg, WeightsEnum) - or (isinstance(weights_arg, str) and weights_arg != "legacy") - or weights_arg is None - ): - continue - - # If the pretrained parameter was passed as positional argument, it is now mapped to - # `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the current - # signature to infer the names of positionally passed arguments and thus has no knowledge that there - # used to be a pretrained parameter. - pretrained_positional = weights_arg is not sentinel - if pretrained_positional: - # We put the pretrained argument under its legacy name in the keyword argument dictionary to have a - # unified access to the value if the default value is a callable. - kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param) - else: - pretrained_arg = kwargs[pretrained_param] - - if pretrained_arg: - default_weights_arg = default(kwargs) if callable(default) else default - if not isinstance(default_weights_arg, WeightsEnum): - raise ValueError(f"No weights available for model {builder.__name__}") - else: - default_weights_arg = None - - if not pretrained_positional: - warnings.warn( - f"The parameter '{pretrained_param}' is deprecated, please use '{weights_param}' instead." - ) - - msg = ( - f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated. " - f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`." - ) - if pretrained_arg: - msg = ( - f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.DEFAULT` " - f"to get the most up-to-date weights." 
- ) - warnings.warn(msg) - - del kwargs[pretrained_param] - kwargs[weights_param] = default_weights_arg - - return builder(*args, **kwargs) - - return inner_wrapper - - return outer_wrapper - - -def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None: - if param in kwargs: - if kwargs[param] != new_value: - raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.") - else: - kwargs[param] = new_value - - -def _ovewrite_value_param(param: Optional[V], new_value: V) -> V: - if param is not None: - if param != new_value: - raise ValueError(f"The parameter '{param}' expected value {new_value} but got {param} instead.") - return new_value diff --git a/torchvision/prototype/models/alexnet.py b/torchvision/prototype/models/alexnet.py deleted file mode 100644 index 204a68236d3..00000000000 --- a/torchvision/prototype/models/alexnet.py +++ /dev/null @@ -1,49 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.alexnet import AlexNet -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"] - - -class AlexNet_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - "task": "image_classification", - "architecture": "AlexNet", - "publication_year": 2012, - "num_params": 61100840, - "size": (224, 224), - "min_size": (63, 63), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg", - "acc@1": 56.522, - "acc@5": 79.066, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1)) -def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: - weights = AlexNet_Weights.verify(weights) - - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = AlexNet(**kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py deleted file mode 100644 index 7d63ee155db..00000000000 --- a/torchvision/prototype/models/convnext.py +++ /dev/null @@ -1,169 +0,0 @@ -from functools import partial -from typing import Any, List, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.convnext import ConvNeXt, CNBlockConfig -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "ConvNeXt", - "ConvNeXt_Tiny_Weights", - "ConvNeXt_Small_Weights", - "ConvNeXt_Base_Weights", - "ConvNeXt_Large_Weights", - "convnext_tiny", - "convnext_small", - "convnext_base", - "convnext_large", -] - - -def _convnext( - block_setting: List[CNBlockConfig], - stochastic_depth_prob: float, - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: 
Any, -) -> ConvNeXt: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "image_classification", - "architecture": "ConvNeXt", - "publication_year": 2022, - "size": (224, 224), - "min_size": (32, 32), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext", -} - - -class ConvNeXt_Tiny_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=236), - meta={ - **_COMMON_META, - "num_params": 28589128, - "acc@1": 82.520, - "acc@5": 96.146, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ConvNeXt_Small_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_small-0c510722.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=230), - meta={ - **_COMMON_META, - "num_params": 50223688, - "acc@1": 83.616, - "acc@5": 96.650, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ConvNeXt_Base_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_base-6075fbad.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 88591464, - "acc@1": 84.062, - "acc@5": 96.870, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ConvNeXt_Large_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_large-ea097f82.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 197767336, - "acc@1": 84.414, - "acc@5": 96.976, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1)) -def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: - weights = ConvNeXt_Tiny_Weights.verify(weights) - - block_setting = [ - CNBlockConfig(96, 192, 3), - CNBlockConfig(192, 384, 3), - CNBlockConfig(384, 768, 9), - CNBlockConfig(768, None, 3), - ] - stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1) - return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1)) -def convnext_small( - *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any -) -> ConvNeXt: - weights = ConvNeXt_Small_Weights.verify(weights) - - block_setting = [ - CNBlockConfig(96, 192, 3), - CNBlockConfig(192, 384, 3), - CNBlockConfig(384, 768, 27), - CNBlockConfig(768, None, 3), - ] - stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4) - return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1)) -def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: - weights = ConvNeXt_Base_Weights.verify(weights) - - block_setting = [ - CNBlockConfig(128, 256, 3), - 
CNBlockConfig(256, 512, 3), - CNBlockConfig(512, 1024, 27), - CNBlockConfig(1024, None, 3), - ] - stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5) - return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1)) -def convnext_large( - *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any -) -> ConvNeXt: - weights = ConvNeXt_Large_Weights.verify(weights) - - block_setting = [ - CNBlockConfig(192, 384, 3), - CNBlockConfig(384, 768, 3), - CNBlockConfig(768, 1536, 27), - CNBlockConfig(1536, None, 3), - ] - stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5) - return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/densenet.py b/torchvision/prototype/models/densenet.py deleted file mode 100644 index 4ad9be028e5..00000000000 --- a/torchvision/prototype/models/densenet.py +++ /dev/null @@ -1,159 +0,0 @@ -import re -from functools import partial -from typing import Any, Optional, Tuple - -import torch.nn as nn -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.densenet import DenseNet -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "DenseNet", - "DenseNet121_Weights", - "DenseNet161_Weights", - "DenseNet169_Weights", - "DenseNet201_Weights", - "densenet121", - "densenet161", - "densenet169", - "densenet201", -] - - -def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None: - # '.'s are no longer allowed in module names, but previous _DenseLayer - # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. - # They are also in the checkpoints in model_urls. This pattern is used - # to find such keys. 
- pattern = re.compile( - r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" - ) - - state_dict = weights.get_state_dict(progress=progress) - for key in list(state_dict.keys()): - res = pattern.match(key) - if res: - new_key = res.group(1) + res.group(2) - state_dict[new_key] = state_dict[key] - del state_dict[key] - model.load_state_dict(state_dict) - - -def _densenet( - growth_rate: int, - block_config: Tuple[int, int, int, int], - num_init_features: int, - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: Any, -) -> DenseNet: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) - - if weights is not None: - _load_state_dict(model=model, weights=weights, progress=progress) - - return model - - -_COMMON_META = { - "task": "image_classification", - "architecture": "DenseNet", - "publication_year": 2016, - "size": (224, 224), - "min_size": (29, 29), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/pull/116", -} - - -class DenseNet121_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/densenet121-a639ec97.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 7978856, - "acc@1": 74.434, - "acc@5": 91.972, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class DenseNet161_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/densenet161-8d451a50.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 28681000, - "acc@1": 77.138, - "acc@5": 93.560, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class DenseNet169_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/densenet169-b2777c0a.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 14149480, - "acc@1": 75.600, - "acc@5": 92.806, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class DenseNet201_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/densenet201-c1103571.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 20013928, - "acc@1": 76.896, - "acc@5": 93.370, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1)) -def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: - weights = DenseNet121_Weights.verify(weights) - - return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1)) -def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: - weights = DenseNet161_Weights.verify(weights) - - return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1)) -def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: - weights = DenseNet169_Weights.verify(weights) - - return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs) - - 
-@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1)) -def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: - weights = DenseNet201_Weights.verify(weights) - - return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py index ecdd9bdb423..5abc0eef1c4 100644 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ b/torchvision/prototype/models/detection/faster_rcnn.py @@ -1,10 +1,10 @@ from typing import Any, Optional, Union from torch import nn -from torchvision.prototype.transforms import ObjectDetectionEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.detection.faster_rcnn import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _COCO_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param +from torchvision.models.detection.faster_rcnn import ( _mobilenet_extractor, _resnet_fpn_extractor, _validate_trainable_layers, @@ -13,11 +13,9 @@ misc_nn_ops, overwrite_eps, ) -from .._api import WeightsEnum, Weights -from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large -from ..resnet import ResNet50_Weights, resnet50 +from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large +from torchvision.models.resnet import ResNet50_Weights, resnet50 +from torchvision.transforms import ObjectDetectionEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/detection/fcos.py b/torchvision/prototype/models/detection/fcos.py index db3a679a62d..930b26e46c8 100644 --- a/torchvision/prototype/models/detection/fcos.py +++ b/torchvision/prototype/models/detection/fcos.py @@ -1,20 +1,18 @@ from typing import Any, Optional from torch import nn -from torchvision.prototype.transforms import ObjectDetectionEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.detection.fcos import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _COCO_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param +from torchvision.models.detection.fcos import ( _resnet_fpn_extractor, _validate_trainable_layers, FCOS, LastLevelP6P7, misc_nn_ops, ) -from .._api import WeightsEnum, Weights -from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, resnet50 +from torchvision.models.resnet import ResNet50_Weights, resnet50 +from torchvision.transforms import ObjectDetectionEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/detection/keypoint_rcnn.py b/torchvision/prototype/models/detection/keypoint_rcnn.py index e0b4d7061fa..a7780cc9f63 100644 --- a/torchvision/prototype/models/detection/keypoint_rcnn.py +++ b/torchvision/prototype/models/detection/keypoint_rcnn.py @@ -1,20 +1,18 @@ from typing import Any, Optional from torch import nn -from torchvision.prototype.transforms import ObjectDetectionEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.detection.keypoint_rcnn import ( +from torchvision.models._api import WeightsEnum, Weights +from 
torchvision.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param +from torchvision.models.detection.keypoint_rcnn import ( _resnet_fpn_extractor, _validate_trainable_layers, KeypointRCNN, misc_nn_ops, overwrite_eps, ) -from .._api import WeightsEnum, Weights -from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, resnet50 +from torchvision.models.resnet import ResNet50_Weights, resnet50 +from torchvision.transforms import ObjectDetectionEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/detection/mask_rcnn.py b/torchvision/prototype/models/detection/mask_rcnn.py index 187bf6912b4..d52ebe61be1 100644 --- a/torchvision/prototype/models/detection/mask_rcnn.py +++ b/torchvision/prototype/models/detection/mask_rcnn.py @@ -1,20 +1,18 @@ from typing import Any, Optional from torch import nn -from torchvision.prototype.transforms import ObjectDetectionEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.detection.mask_rcnn import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _COCO_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param +from torchvision.models.detection.mask_rcnn import ( _resnet_fpn_extractor, _validate_trainable_layers, MaskRCNN, misc_nn_ops, overwrite_eps, ) -from .._api import WeightsEnum, Weights -from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, resnet50 +from torchvision.models.resnet import ResNet50_Weights, resnet50 +from torchvision.transforms import ObjectDetectionEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/detection/retinanet.py b/torchvision/prototype/models/detection/retinanet.py index eadd6c635ca..c4249118b70 100644 --- a/torchvision/prototype/models/detection/retinanet.py +++ b/torchvision/prototype/models/detection/retinanet.py @@ -1,10 +1,10 @@ from typing import Any, Optional from torch import nn -from torchvision.prototype.transforms import ObjectDetectionEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.detection.retinanet import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _COCO_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param +from torchvision.models.detection.retinanet import ( _resnet_fpn_extractor, _validate_trainable_layers, RetinaNet, @@ -12,10 +12,8 @@ misc_nn_ops, overwrite_eps, ) -from .._api import WeightsEnum, Weights -from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, resnet50 +from torchvision.models.resnet import ResNet50_Weights, resnet50 +from torchvision.transforms import ObjectDetectionEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/detection/ssd.py b/torchvision/prototype/models/detection/ssd.py index 3cab044958d..a3c5b965deb 100644 --- a/torchvision/prototype/models/detection/ssd.py +++ b/torchvision/prototype/models/detection/ssd.py @@ -1,19 +1,17 @@ import warnings from typing import Any, Optional -from torchvision.prototype.transforms import ObjectDetectionEval -from 
torchvision.transforms.functional import InterpolationMode - -from ....models.detection.ssd import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _COCO_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param +from torchvision.models.detection.ssd import ( _validate_trainable_layers, _vgg_extractor, DefaultBoxGenerator, SSD, ) -from .._api import WeightsEnum, Weights -from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..vgg import VGG16_Weights, vgg16 +from torchvision.models.vgg import VGG16_Weights, vgg16 +from torchvision.transforms import ObjectDetectionEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/detection/ssdlite.py b/torchvision/prototype/models/detection/ssdlite.py index 6de34acb5ae..d9f2ee58bc6 100644 --- a/torchvision/prototype/models/detection/ssdlite.py +++ b/torchvision/prototype/models/detection/ssdlite.py @@ -3,10 +3,10 @@ from typing import Any, Callable, Optional from torch import nn -from torchvision.prototype.transforms import ObjectDetectionEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.detection.ssdlite import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _COCO_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param +from torchvision.models.detection.ssdlite import ( _mobilenet_extractor, _normal_init, _validate_trainable_layers, @@ -15,10 +15,8 @@ SSD, SSDLiteHead, ) -from .._api import WeightsEnum, Weights -from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large +from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large +from torchvision.transforms import ObjectDetectionEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/efficientnet.py b/torchvision/prototype/models/efficientnet.py deleted file mode 100644 index cb6d2bb2b35..00000000000 --- a/torchvision/prototype/models/efficientnet.py +++ /dev/null @@ -1,453 +0,0 @@ -from functools import partial -from typing import Any, Optional, Sequence, Union - -from torch import nn -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.efficientnet import EfficientNet, MBConvConfig, FusedMBConvConfig, _efficientnet_conf -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "EfficientNet", - "EfficientNet_B0_Weights", - "EfficientNet_B1_Weights", - "EfficientNet_B2_Weights", - "EfficientNet_B3_Weights", - "EfficientNet_B4_Weights", - "EfficientNet_B5_Weights", - "EfficientNet_B6_Weights", - "EfficientNet_B7_Weights", - "EfficientNet_V2_S_Weights", - "EfficientNet_V2_M_Weights", - "EfficientNet_V2_L_Weights", - "efficientnet_b0", - "efficientnet_b1", - "efficientnet_b2", - "efficientnet_b3", - "efficientnet_b4", - "efficientnet_b5", - "efficientnet_b6", - "efficientnet_b7", - "efficientnet_v2_s", - "efficientnet_v2_m", - "efficientnet_v2_l", -] - - -def _efficientnet( - inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]], - dropout: float, - last_channel: Optional[int], - weights: Optional[WeightsEnum], 
- progress: bool, - **kwargs: Any, -) -> EfficientNet: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "image_classification", - "categories": _IMAGENET_CATEGORIES, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet", -} - - -_COMMON_META_V1 = { - **_COMMON_META, - "architecture": "EfficientNet", - "publication_year": 2019, - "interpolation": InterpolationMode.BICUBIC, - "min_size": (1, 1), -} - - -_COMMON_META_V2 = { - **_COMMON_META, - "architecture": "EfficientNetV2", - "publication_year": 2021, - "interpolation": InterpolationMode.BILINEAR, - "min_size": (33, 33), -} - - -class EfficientNet_B0_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", - transforms=partial( - ImageClassificationEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 5288548, - "size": (224, 224), - "acc@1": 77.692, - "acc@5": 93.532, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_B1_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", - transforms=partial( - ImageClassificationEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 7794184, - "size": (240, 240), - "acc@1": 78.642, - "acc@5": 94.186, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth", - transforms=partial( - ImageClassificationEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR - ), - meta={ - **_COMMON_META_V1, - "num_params": 7794184, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning", - "interpolation": InterpolationMode.BILINEAR, - "size": (240, 240), - "acc@1": 79.838, - "acc@5": 94.934, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class EfficientNet_B2_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", - transforms=partial( - ImageClassificationEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 9109994, - "size": (288, 288), - "acc@1": 80.608, - "acc@5": 95.310, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_B3_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", - transforms=partial( - ImageClassificationEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 12233232, - "size": (300, 300), - "acc@1": 82.008, - "acc@5": 96.054, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_B4_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", - transforms=partial( - ImageClassificationEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 19341616, - "size": 
(380, 380), - "acc@1": 83.384, - "acc@5": 96.594, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_B5_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", - transforms=partial( - ImageClassificationEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 30389784, - "size": (456, 456), - "acc@1": 83.444, - "acc@5": 96.628, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_B6_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", - transforms=partial( - ImageClassificationEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 43040704, - "size": (528, 528), - "acc@1": 84.008, - "acc@5": 96.916, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_B7_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", - transforms=partial( - ImageClassificationEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 66347960, - "size": (600, 600), - "acc@1": 84.122, - "acc@5": 96.908, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_V2_S_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth", - transforms=partial( - ImageClassificationEval, - crop_size=384, - resize_size=384, - interpolation=InterpolationMode.BILINEAR, - ), - meta={ - **_COMMON_META_V2, - "num_params": 21458488, - "size": (384, 384), - "acc@1": 84.228, - "acc@5": 96.878, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_V2_M_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth", - transforms=partial( - ImageClassificationEval, - crop_size=480, - resize_size=480, - interpolation=InterpolationMode.BILINEAR, - ), - meta={ - **_COMMON_META_V2, - "num_params": 54139356, - "size": (480, 480), - "acc@1": 85.112, - "acc@5": 97.156, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_V2_L_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth", - transforms=partial( - ImageClassificationEval, - crop_size=480, - resize_size=480, - interpolation=InterpolationMode.BICUBIC, - mean=(0.5, 0.5, 0.5), - std=(0.5, 0.5, 0.5), - ), - meta={ - **_COMMON_META_V2, - "num_params": 118515272, - "size": (480, 480), - "acc@1": 85.808, - "acc@5": 97.788, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1)) -def efficientnet_b0( - *, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B0_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0) - return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1)) -def efficientnet_b1( - *, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B1_Weights.verify(weights) - - 
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1) - return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1)) -def efficientnet_b2( - *, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B2_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2) - return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1)) -def efficientnet_b3( - *, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B3_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4) - return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1)) -def efficientnet_b4( - *, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B4_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8) - return _efficientnet(inverted_residual_setting, 0.4, last_channel, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1)) -def efficientnet_b5( - *, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B5_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2) - return _efficientnet( - inverted_residual_setting, - 0.4, - last_channel, - weights, - progress, - norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1)) -def efficientnet_b6( - *, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B6_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6) - return _efficientnet( - inverted_residual_setting, - 0.5, - last_channel, - weights, - progress, - norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1)) -def efficientnet_b7( - *, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B7_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1) - return _efficientnet( - inverted_residual_setting, - 0.5, - last_channel, - weights, - progress, - norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1)) -def efficientnet_v2_s( - *, weights: 
Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_V2_S_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s") - return _efficientnet( - inverted_residual_setting, - 0.2, - last_channel, - weights, - progress, - norm_layer=partial(nn.BatchNorm2d, eps=1e-03), - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1)) -def efficientnet_v2_m( - *, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_V2_M_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_m") - return _efficientnet( - inverted_residual_setting, - 0.3, - last_channel, - weights, - progress, - norm_layer=partial(nn.BatchNorm2d, eps=1e-03), - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1)) -def efficientnet_v2_l( - *, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_V2_L_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l") - return _efficientnet( - inverted_residual_setting, - 0.4, - last_channel, - weights, - progress, - norm_layer=partial(nn.BatchNorm2d, eps=1e-03), - **kwargs, - ) diff --git a/torchvision/prototype/models/googlenet.py b/torchvision/prototype/models/googlenet.py deleted file mode 100644 index 70dc0d9db5c..00000000000 --- a/torchvision/prototype/models/googlenet.py +++ /dev/null @@ -1,63 +0,0 @@ -import warnings -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"] - - -class GoogLeNet_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/googlenet-1378be20.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - "task": "image_classification", - "architecture": "GoogLeNet", - "publication_year": 2014, - "num_params": 6624904, - "size": (224, 224), - "min_size": (15, 15), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#googlenet", - "acc@1": 69.778, - "acc@5": 89.530, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.IMAGENET1K_V1)) -def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: - weights = GoogLeNet_Weights.verify(weights) - - original_aux_logits = kwargs.get("aux_logits", False) - if weights is not None: - if "transform_input" not in kwargs: - _ovewrite_named_param(kwargs, "transform_input", True) - _ovewrite_named_param(kwargs, "aux_logits", True) - _ovewrite_named_param(kwargs, "init_weights", False) - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = 
GoogLeNet(**kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - if not original_aux_logits: - model.aux_logits = False - model.aux1 = None # type: ignore[assignment] - model.aux2 = None # type: ignore[assignment] - else: - warnings.warn( - "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" - ) - - return model diff --git a/torchvision/prototype/models/inception.py b/torchvision/prototype/models/inception.py deleted file mode 100644 index eec78a26236..00000000000 --- a/torchvision/prototype/models/inception.py +++ /dev/null @@ -1,57 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"] - - -class Inception_V3_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", - transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342), - meta={ - "task": "image_classification", - "architecture": "InceptionV3", - "publication_year": 2015, - "num_params": 27161264, - "size": (299, 299), - "min_size": (75, 75), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#inception-v3", - "acc@1": 77.294, - "acc@5": 93.450, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.IMAGENET1K_V1)) -def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: - weights = Inception_V3_Weights.verify(weights) - - original_aux_logits = kwargs.get("aux_logits", True) - if weights is not None: - if "transform_input" not in kwargs: - _ovewrite_named_param(kwargs, "transform_input", True) - _ovewrite_named_param(kwargs, "aux_logits", True) - _ovewrite_named_param(kwargs, "init_weights", False) - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = Inception3(**kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - if not original_aux_logits: - model.aux_logits = False - model.AuxLogits = None - - return model diff --git a/torchvision/prototype/models/mnasnet.py b/torchvision/prototype/models/mnasnet.py deleted file mode 100644 index c48e34a7be5..00000000000 --- a/torchvision/prototype/models/mnasnet.py +++ /dev/null @@ -1,113 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.mnasnet import MNASNet -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "MNASNet", - "MNASNet0_5_Weights", - "MNASNet0_75_Weights", - "MNASNet1_0_Weights", - "MNASNet1_3_Weights", - "mnasnet0_5", - "mnasnet0_75", - "mnasnet1_0", - "mnasnet1_3", -] - - 
-_COMMON_META = { - "task": "image_classification", - "architecture": "MNASNet", - "publication_year": 2018, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/1e100/mnasnet_trainer", -} - - -class MNASNet0_5_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 2218512, - "acc@1": 67.734, - "acc@5": 87.490, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class MNASNet0_75_Weights(WeightsEnum): - # If a default model is added here the corresponding changes need to be done in mnasnet0_75 - pass - - -class MNASNet1_0_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 4383312, - "acc@1": 73.456, - "acc@5": 91.510, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class MNASNet1_3_Weights(WeightsEnum): - # If a default model is added here the corresponding changes need to be done in mnasnet1_3 - pass - - -def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = MNASNet(alpha, **kwargs) - - if weights: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -@handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.IMAGENET1K_V1)) -def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: - weights = MNASNet0_5_Weights.verify(weights) - - return _mnasnet(0.5, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", None)) -def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: - weights = MNASNet0_75_Weights.verify(weights) - - return _mnasnet(0.75, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.IMAGENET1K_V1)) -def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: - weights = MNASNet1_0_Weights.verify(weights) - - return _mnasnet(1.0, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", None)) -def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: - weights = MNASNet1_3_Weights.verify(weights) - - return _mnasnet(1.3, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/mobilenet.py b/torchvision/prototype/models/mobilenet.py deleted file mode 100644 index 0a270d14d3a..00000000000 --- a/torchvision/prototype/models/mobilenet.py +++ /dev/null @@ -1,6 +0,0 @@ -from .mobilenetv2 import * # noqa: F401, F403 -from .mobilenetv3 import * # noqa: F401, F403 -from .mobilenetv2 import __all__ as mv2_all -from .mobilenetv3 import __all__ as mv3_all - -__all__ = mv2_all + mv3_all diff --git a/torchvision/prototype/models/mobilenetv2.py b/torchvision/prototype/models/mobilenetv2.py deleted file mode 100644 index 71b412898fe..00000000000 --- a/torchvision/prototype/models/mobilenetv2.py +++ /dev/null @@ -1,66 +0,0 @@ -from functools import partial -from typing import Any, Optional 
- -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.mobilenetv2 import MobileNetV2 -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"] - - -_COMMON_META = { - "task": "image_classification", - "architecture": "MobileNetV2", - "publication_year": 2018, - "num_params": 3504872, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, -} - - -class MobileNet_V2_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2", - "acc@1": 71.878, - "acc@5": 90.286, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", - "acc@1": 72.154, - "acc@5": 90.822, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -@handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1)) -def mobilenet_v2( - *, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any -) -> MobileNetV2: - weights = MobileNet_V2_Weights.verify(weights) - - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = MobileNetV2(**kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model diff --git a/torchvision/prototype/models/mobilenetv3.py b/torchvision/prototype/models/mobilenetv3.py deleted file mode 100644 index aaf9c2c85a4..00000000000 --- a/torchvision/prototype/models/mobilenetv3.py +++ /dev/null @@ -1,109 +0,0 @@ -from functools import partial -from typing import Any, Optional, List - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "MobileNetV3", - "MobileNet_V3_Large_Weights", - "MobileNet_V3_Small_Weights", - "mobilenet_v3_large", - "mobilenet_v3_small", -] - - -def _mobilenet_v3( - inverted_residual_setting: List[InvertedResidualConfig], - last_channel: int, - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: Any, -) -> MobileNetV3: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "image_classification", - "architecture": "MobileNetV3", - "publication_year": 2019, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, -} 
- - -class MobileNet_V3_Large_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 5483032, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small", - "acc@1": 74.042, - "acc@5": 91.340, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 5483032, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", - "acc@1": 75.274, - "acc@5": 92.566, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class MobileNet_V3_Small_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 2542856, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small", - "acc@1": 67.668, - "acc@5": 87.402, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.IMAGENET1K_V1)) -def mobilenet_v3_large( - *, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any -) -> MobileNetV3: - weights = MobileNet_V3_Large_Weights.verify(weights) - - inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) - return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1)) -def mobilenet_v3_small( - *, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any -) -> MobileNetV3: - weights = MobileNet_V3_Small_Weights.verify(weights) - - inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs) - return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py index 24e87f3d4f9..33e3243c2a0 100644 --- a/torchvision/prototype/models/optical_flow/raft.py +++ b/torchvision/prototype/models/optical_flow/raft.py @@ -2,14 +2,12 @@ from torch.nn.modules.batchnorm import BatchNorm2d from torch.nn.modules.instancenorm import InstanceNorm2d +from torchvision.models._api import Weights +from torchvision.models._api import WeightsEnum +from torchvision.models._utils import handle_legacy_interface from torchvision.models.optical_flow import RAFT from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock -from torchvision.prototype.transforms import OpticalFlowEval -from torchvision.transforms.functional import InterpolationMode - -from .._api import WeightsEnum -from .._api import Weights -from .._utils import handle_legacy_interface +from torchvision.transforms import OpticalFlowEval, InterpolationMode __all__ = ( diff --git a/torchvision/prototype/models/quantization/__init__.py b/torchvision/prototype/models/quantization/__init__.py deleted file mode 100644 index da8bbba3567..00000000000 --- a/torchvision/prototype/models/quantization/__init__.py +++ /dev/null @@ 
-1,5 +0,0 @@ -from .googlenet import * -from .inception import * -from .mobilenet import * -from .resnet import * -from .shufflenetv2 import * diff --git a/torchvision/prototype/models/quantization/googlenet.py b/torchvision/prototype/models/quantization/googlenet.py deleted file mode 100644 index cca6ba25060..00000000000 --- a/torchvision/prototype/models/quantization/googlenet.py +++ /dev/null @@ -1,94 +0,0 @@ -import warnings -from functools import partial -from typing import Any, Optional, Union - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.quantization.googlenet import ( - QuantizableGoogLeNet, - _replace_relu, - quantize_model, -) -from .._api import WeightsEnum, Weights -from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param -from ..googlenet import GoogLeNet_Weights - - -__all__ = [ - "QuantizableGoogLeNet", - "GoogLeNet_QuantizedWeights", - "googlenet", -] - - -class GoogLeNet_QuantizedWeights(WeightsEnum): - IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - "task": "image_classification", - "architecture": "GoogLeNet", - "publication_year": 2014, - "num_params": 6624904, - "size": (224, 224), - "min_size": (15, 15), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "backend": "fbgemm", - "quantization": "ptq", - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", - "unquantized": GoogLeNet_Weights.IMAGENET1K_V1, - "acc@1": 69.826, - "acc@5": 89.404, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V1 - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else GoogLeNet_Weights.IMAGENET1K_V1, - ) -) -def googlenet( - *, - weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableGoogLeNet: - weights = (GoogLeNet_QuantizedWeights if quantize else GoogLeNet_Weights).verify(weights) - - original_aux_logits = kwargs.get("aux_logits", False) - if weights is not None: - if "transform_input" not in kwargs: - _ovewrite_named_param(kwargs, "transform_input", True) - _ovewrite_named_param(kwargs, "aux_logits", True) - _ovewrite_named_param(kwargs, "init_weights", False) - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - if "backend" in weights.meta: - _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) - backend = kwargs.pop("backend", "fbgemm") - - model = QuantizableGoogLeNet(**kwargs) - _replace_relu(model) - if quantize: - quantize_model(model, backend) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - if not original_aux_logits: - model.aux_logits = False - model.aux1 = None # type: ignore[assignment] - model.aux2 = None # type: ignore[assignment] - else: - warnings.warn( - "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" - ) - - return model diff --git a/torchvision/prototype/models/quantization/inception.py b/torchvision/prototype/models/quantization/inception.py deleted file mode 100644 index 2639b7de14f..00000000000 --- 
a/torchvision/prototype/models/quantization/inception.py +++ /dev/null @@ -1,90 +0,0 @@ -from functools import partial -from typing import Any, Optional, Union - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.quantization.inception import ( - QuantizableInception3, - _replace_relu, - quantize_model, -) -from .._api import WeightsEnum, Weights -from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param -from ..inception import Inception_V3_Weights - - -__all__ = [ - "QuantizableInception3", - "Inception_V3_QuantizedWeights", - "inception_v3", -] - - -class Inception_V3_QuantizedWeights(WeightsEnum): - IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth", - transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342), - meta={ - "task": "image_classification", - "architecture": "InceptionV3", - "publication_year": 2015, - "num_params": 27161264, - "size": (299, 299), - "min_size": (75, 75), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "backend": "fbgemm", - "quantization": "ptq", - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", - "unquantized": Inception_V3_Weights.IMAGENET1K_V1, - "acc@1": 77.176, - "acc@5": 93.354, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V1 - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: Inception_V3_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else Inception_V3_Weights.IMAGENET1K_V1, - ) -) -def inception_v3( - *, - weights: Optional[Union[Inception_V3_QuantizedWeights, Inception_V3_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableInception3: - weights = (Inception_V3_QuantizedWeights if quantize else Inception_V3_Weights).verify(weights) - - original_aux_logits = kwargs.get("aux_logits", False) - if weights is not None: - if "transform_input" not in kwargs: - _ovewrite_named_param(kwargs, "transform_input", True) - _ovewrite_named_param(kwargs, "aux_logits", True) - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - if "backend" in weights.meta: - _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) - backend = kwargs.pop("backend", "fbgemm") - - model = QuantizableInception3(**kwargs) - _replace_relu(model) - if quantize: - quantize_model(model, backend) - - if weights is not None: - if quantize and not original_aux_logits: - model.aux_logits = False - model.AuxLogits = None - model.load_state_dict(weights.get_state_dict(progress=progress)) - if not quantize and not original_aux_logits: - model.aux_logits = False - model.AuxLogits = None - - return model diff --git a/torchvision/prototype/models/quantization/mobilenet.py b/torchvision/prototype/models/quantization/mobilenet.py deleted file mode 100644 index 0a270d14d3a..00000000000 --- a/torchvision/prototype/models/quantization/mobilenet.py +++ /dev/null @@ -1,6 +0,0 @@ -from .mobilenetv2 import * # noqa: F401, F403 -from .mobilenetv3 import * # noqa: F401, F403 -from .mobilenetv2 import __all__ as mv2_all -from .mobilenetv3 import __all__ as mv3_all - -__all__ = mv2_all + mv3_all diff --git a/torchvision/prototype/models/quantization/mobilenetv2.py 
b/torchvision/prototype/models/quantization/mobilenetv2.py deleted file mode 100644 index a9789583fe6..00000000000 --- a/torchvision/prototype/models/quantization/mobilenetv2.py +++ /dev/null @@ -1,81 +0,0 @@ -from functools import partial -from typing import Any, Optional, Union - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.quantization.mobilenetv2 import ( - QuantizableInvertedResidual, - QuantizableMobileNetV2, - _replace_relu, - quantize_model, -) -from .._api import WeightsEnum, Weights -from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param -from ..mobilenetv2 import MobileNet_V2_Weights - - -__all__ = [ - "QuantizableMobileNetV2", - "MobileNet_V2_QuantizedWeights", - "mobilenet_v2", -] - - -class MobileNet_V2_QuantizedWeights(WeightsEnum): - IMAGENET1K_QNNPACK_V1 = Weights( - url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - "task": "image_classification", - "architecture": "MobileNetV2", - "publication_year": 2018, - "num_params": 3504872, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "backend": "qnnpack", - "quantization": "qat", - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2", - "unquantized": MobileNet_V2_Weights.IMAGENET1K_V1, - "acc@1": 71.658, - "acc@5": 90.150, - }, - ) - DEFAULT = IMAGENET1K_QNNPACK_V1 - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1 - if kwargs.get("quantize", False) - else MobileNet_V2_Weights.IMAGENET1K_V1, - ) -) -def mobilenet_v2( - *, - weights: Optional[Union[MobileNet_V2_QuantizedWeights, MobileNet_V2_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableMobileNetV2: - weights = (MobileNet_V2_QuantizedWeights if quantize else MobileNet_V2_Weights).verify(weights) - - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - if "backend" in weights.meta: - _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) - backend = kwargs.pop("backend", "qnnpack") - - model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs) - _replace_relu(model) - if quantize: - quantize_model(model, backend) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model diff --git a/torchvision/prototype/models/quantization/mobilenetv3.py b/torchvision/prototype/models/quantization/mobilenetv3.py deleted file mode 100644 index 915308d948f..00000000000 --- a/torchvision/prototype/models/quantization/mobilenetv3.py +++ /dev/null @@ -1,101 +0,0 @@ -from functools import partial -from typing import Any, List, Optional, Union - -import torch -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.quantization.mobilenetv3 import ( - InvertedResidualConfig, - QuantizableInvertedResidual, - QuantizableMobileNetV3, - _replace_relu, -) -from .._api import WeightsEnum, Weights -from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param -from ..mobilenetv3 import 
MobileNet_V3_Large_Weights, _mobilenet_v3_conf - - -__all__ = [ - "QuantizableMobileNetV3", - "MobileNet_V3_Large_QuantizedWeights", - "mobilenet_v3_large", -] - - -def _mobilenet_v3_model( - inverted_residual_setting: List[InvertedResidualConfig], - last_channel: int, - weights: Optional[WeightsEnum], - progress: bool, - quantize: bool, - **kwargs: Any, -) -> QuantizableMobileNetV3: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - if "backend" in weights.meta: - _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) - backend = kwargs.pop("backend", "qnnpack") - - model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs) - _replace_relu(model) - - if quantize: - model.fuse_model(is_qat=True) - model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend) - torch.ao.quantization.prepare_qat(model, inplace=True) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - if quantize: - torch.ao.quantization.convert(model, inplace=True) - model.eval() - - return model - - -class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): - IMAGENET1K_QNNPACK_V1 = Weights( - url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - "task": "image_classification", - "architecture": "MobileNetV3", - "publication_year": 2019, - "num_params": 5483032, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "backend": "qnnpack", - "quantization": "qat", - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3", - "unquantized": MobileNet_V3_Large_Weights.IMAGENET1K_V1, - "acc@1": 73.004, - "acc@5": 90.858, - }, - ) - DEFAULT = IMAGENET1K_QNNPACK_V1 - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1 - if kwargs.get("quantize", False) - else MobileNet_V3_Large_Weights.IMAGENET1K_V1, - ) -) -def mobilenet_v3_large( - *, - weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableMobileNetV3: - weights = (MobileNet_V3_Large_QuantizedWeights if quantize else MobileNet_V3_Large_Weights).verify(weights) - - inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) - return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs) diff --git a/torchvision/prototype/models/quantization/resnet.py b/torchvision/prototype/models/quantization/resnet.py deleted file mode 100644 index 9e2e29db0bf..00000000000 --- a/torchvision/prototype/models/quantization/resnet.py +++ /dev/null @@ -1,204 +0,0 @@ -from functools import partial -from typing import Any, List, Optional, Type, Union - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.quantization.resnet import ( - QuantizableBasicBlock, - QuantizableBottleneck, - QuantizableResNet, - _replace_relu, - quantize_model, -) -from .._api import WeightsEnum, Weights -from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param -from 
..resnet import ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights - - -__all__ = [ - "QuantizableResNet", - "ResNet18_QuantizedWeights", - "ResNet50_QuantizedWeights", - "ResNeXt101_32X8D_QuantizedWeights", - "resnet18", - "resnet50", - "resnext101_32x8d", -] - - -def _resnet( - block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]], - layers: List[int], - weights: Optional[WeightsEnum], - progress: bool, - quantize: bool, - **kwargs: Any, -) -> QuantizableResNet: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - if "backend" in weights.meta: - _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) - backend = kwargs.pop("backend", "fbgemm") - - model = QuantizableResNet(block, layers, **kwargs) - _replace_relu(model) - if quantize: - quantize_model(model, backend) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "image_classification", - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "backend": "fbgemm", - "quantization": "ptq", - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", -} - - -class ResNet18_QuantizedWeights(WeightsEnum): - IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 11689512, - "unquantized": ResNet18_Weights.IMAGENET1K_V1, - "acc@1": 69.494, - "acc@5": 88.882, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V1 - - -class ResNet50_QuantizedWeights(WeightsEnum): - IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 25557032, - "unquantized": ResNet50_Weights.IMAGENET1K_V1, - "acc@1": 75.920, - "acc@5": 92.814, - }, - ) - IMAGENET1K_FBGEMM_V2 = Weights( - url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 25557032, - "unquantized": ResNet50_Weights.IMAGENET1K_V2, - "acc@1": 80.282, - "acc@5": 94.976, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V2 - - -class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): - IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNeXt", - "publication_year": 2016, - "num_params": 88791336, - "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V1, - "acc@1": 78.986, - "acc@5": 94.480, - }, - ) - IMAGENET1K_FBGEMM_V2 = Weights( - url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "ResNeXt", - "publication_year": 2016, - "num_params": 88791336, - "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V2, - "acc@1": 82.574, - "acc@5": 
96.132, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V2 - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else ResNet18_Weights.IMAGENET1K_V1, - ) -) -def resnet18( - *, - weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableResNet: - weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights) - - return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs) - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else ResNet50_Weights.IMAGENET1K_V1, - ) -) -def resnet50( - *, - weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableResNet: - weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights) - - return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs) - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else ResNeXt101_32X8D_Weights.IMAGENET1K_V1, - ) -) -def resnext101_32x8d( - *, - weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableResNet: - weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights) - - _ovewrite_named_param(kwargs, "groups", 32) - _ovewrite_named_param(kwargs, "width_per_group", 8) - return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs) diff --git a/torchvision/prototype/models/quantization/shufflenetv2.py b/torchvision/prototype/models/quantization/shufflenetv2.py deleted file mode 100644 index e21349ff8e0..00000000000 --- a/torchvision/prototype/models/quantization/shufflenetv2.py +++ /dev/null @@ -1,136 +0,0 @@ -from functools import partial -from typing import Any, List, Optional, Union - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.quantization.shufflenetv2 import ( - QuantizableShuffleNetV2, - _replace_relu, - quantize_model, -) -from .._api import WeightsEnum, Weights -from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param -from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights - - -__all__ = [ - "QuantizableShuffleNetV2", - "ShuffleNet_V2_X0_5_QuantizedWeights", - "ShuffleNet_V2_X1_0_QuantizedWeights", - "shufflenet_v2_x0_5", - "shufflenet_v2_x1_0", -] - - -def _shufflenetv2( - stages_repeats: List[int], - stages_out_channels: List[int], - *, - weights: Optional[WeightsEnum], - progress: bool, - quantize: bool, - **kwargs: Any, -) -> QuantizableShuffleNetV2: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - if "backend" in weights.meta: - _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) - backend = kwargs.pop("backend", "fbgemm") - - model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs) - _replace_relu(model) - if 
quantize: - quantize_model(model, backend) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "image_classification", - "architecture": "ShuffleNetV2", - "publication_year": 2018, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "backend": "fbgemm", - "quantization": "ptq", - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", -} - - -class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): - IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 1366792, - "unquantized": ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1, - "acc@1": 57.972, - "acc@5": 79.780, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V1 - - -class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): - IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 2278604, - "unquantized": ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1, - "acc@1": 68.360, - "acc@5": 87.582, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V1 - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1, - ) -) -def shufflenet_v2_x0_5( - *, - weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableShuffleNetV2: - weights = (ShuffleNet_V2_X0_5_QuantizedWeights if quantize else ShuffleNet_V2_X0_5_Weights).verify(weights) - return _shufflenetv2( - [4, 8, 4], [24, 48, 96, 192, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs - ) - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1, - ) -) -def shufflenet_v2_x1_0( - *, - weights: Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableShuffleNetV2: - weights = (ShuffleNet_V2_X1_0_QuantizedWeights if quantize else ShuffleNet_V2_X1_0_Weights).verify(weights) - return _shufflenetv2( - [4, 8, 4], [24, 116, 232, 464, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs - ) diff --git a/torchvision/prototype/models/regnet.py b/torchvision/prototype/models/regnet.py deleted file mode 100644 index d5e2b535532..00000000000 --- a/torchvision/prototype/models/regnet.py +++ /dev/null @@ -1,575 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torch import nn -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.regnet import RegNet, BlockParams -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "RegNet", - 
"RegNet_Y_400MF_Weights", - "RegNet_Y_800MF_Weights", - "RegNet_Y_1_6GF_Weights", - "RegNet_Y_3_2GF_Weights", - "RegNet_Y_8GF_Weights", - "RegNet_Y_16GF_Weights", - "RegNet_Y_32GF_Weights", - "RegNet_Y_128GF_Weights", - "RegNet_X_400MF_Weights", - "RegNet_X_800MF_Weights", - "RegNet_X_1_6GF_Weights", - "RegNet_X_3_2GF_Weights", - "RegNet_X_8GF_Weights", - "RegNet_X_16GF_Weights", - "RegNet_X_32GF_Weights", - "regnet_y_400mf", - "regnet_y_800mf", - "regnet_y_1_6gf", - "regnet_y_3_2gf", - "regnet_y_8gf", - "regnet_y_16gf", - "regnet_y_32gf", - "regnet_y_128gf", - "regnet_x_400mf", - "regnet_x_800mf", - "regnet_x_1_6gf", - "regnet_x_3_2gf", - "regnet_x_8gf", - "regnet_x_16gf", - "regnet_x_32gf", -] - -_COMMON_META = { - "task": "image_classification", - "architecture": "RegNet", - "publication_year": 2020, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, -} - - -def _regnet( - block_params: BlockParams, - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: Any, -) -> RegNet: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1)) - model = RegNet(block_params, norm_layer=norm_layer, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -class RegNet_Y_400MF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 4344144, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", - "acc@1": 74.046, - "acc@5": 91.716, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 4344144, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 75.804, - "acc@5": 92.742, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_800MF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 6432512, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", - "acc@1": 76.420, - "acc@5": 93.136, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 6432512, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 78.828, - "acc@5": 94.502, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_1_6GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 11202430, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", - "acc@1": 77.950, - "acc@5": 93.966, - }, - ) - IMAGENET1K_V2 = Weights( - 
url="https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 11202430, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 80.876, - "acc@5": 95.444, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_3_2GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 19436338, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", - "acc@1": 78.948, - "acc@5": 94.576, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 19436338, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 81.982, - "acc@5": 95.972, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_8GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 39381472, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", - "acc@1": 80.032, - "acc@5": 95.048, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 39381472, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 82.828, - "acc@5": 96.330, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_16GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 83590140, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", - "acc@1": 80.424, - "acc@5": 95.240, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 83590140, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 82.886, - "acc@5": 96.328, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_32GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 145046770, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", - "acc@1": 80.878, - "acc@5": 95.340, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 145046770, - "recipe": 
"https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 83.368, - "acc@5": 96.498, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_128GF_Weights(WeightsEnum): - # weights are not available yet. - pass - - -class RegNet_X_400MF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 5495976, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", - "acc@1": 72.834, - "acc@5": 90.950, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 5495976, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", - "acc@1": 74.864, - "acc@5": 92.322, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_X_800MF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 7259656, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", - "acc@1": 75.212, - "acc@5": 92.348, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 7259656, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", - "acc@1": 77.522, - "acc@5": 93.826, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_X_1_6GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 9190136, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", - "acc@1": 77.040, - "acc@5": 93.440, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 9190136, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", - "acc@1": 79.668, - "acc@5": 94.922, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_X_3_2GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 15296552, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", - "acc@1": 78.364, - "acc@5": 93.992, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 15296552, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 81.196, - "acc@5": 95.430, - }, - ) - DEFAULT = 
IMAGENET1K_V2 - - -class RegNet_X_8GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 39572648, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", - "acc@1": 79.344, - "acc@5": 94.686, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 39572648, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 81.682, - "acc@5": 95.678, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_X_16GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 54278536, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", - "acc@1": 80.058, - "acc@5": 94.944, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 54278536, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 82.716, - "acc@5": 96.196, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_X_32GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 107811560, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", - "acc@1": 80.622, - "acc@5": 95.248, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 107811560, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 83.014, - "acc@5": 96.288, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -@handle_legacy_interface(weights=("pretrained", RegNet_Y_400MF_Weights.IMAGENET1K_V1)) -def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_400MF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_Y_800MF_Weights.IMAGENET1K_V1)) -def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_800MF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_Y_1_6GF_Weights.IMAGENET1K_V1)) -def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: 
Any) -> RegNet: - weights = RegNet_Y_1_6GF_Weights.verify(weights) - - params = BlockParams.from_init_params( - depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs - ) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_Y_3_2GF_Weights.IMAGENET1K_V1)) -def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_3_2GF_Weights.verify(weights) - - params = BlockParams.from_init_params( - depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs - ) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_Y_8GF_Weights.IMAGENET1K_V1)) -def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_8GF_Weights.verify(weights) - - params = BlockParams.from_init_params( - depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs - ) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_Y_16GF_Weights.IMAGENET1K_V1)) -def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_16GF_Weights.verify(weights) - - params = BlockParams.from_init_params( - depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs - ) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_Y_32GF_Weights.IMAGENET1K_V1)) -def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_32GF_Weights.verify(weights) - - params = BlockParams.from_init_params( - depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs - ) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", None)) -def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_128GF_Weights.verify(weights) - - params = BlockParams.from_init_params( - depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs - ) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.IMAGENET1K_V1)) -def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_X_400MF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_X_800MF_Weights.IMAGENET1K_V1)) -def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_X_800MF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_X_1_6GF_Weights.IMAGENET1K_V1)) -def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = 
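The builders in these hunks are decorated with handle_legacy_interface, so both the new weights argument and the old pretrained flag are accepted. A usage sketch, with builder and enum names taken from the hunks above and the legacy mapping following the decorator's default argument:

```python
from torchvision.models import regnet_x_400mf, RegNet_X_400MF_Weights

# New-style call: pass an explicit weights enum (or None for random init).
model = regnet_x_400mf(weights=RegNet_X_400MF_Weights.IMAGENET1K_V2)

# Legacy call: handle_legacy_interface maps pretrained=True to the
# IMAGENET1K_V1 entry and warns that the flag is deprecated.
legacy_model = regnet_x_400mf(pretrained=True)
```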
RegNet_X_1_6GF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_X_3_2GF_Weights.IMAGENET1K_V1)) -def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_X_3_2GF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_X_8GF_Weights.IMAGENET1K_V1)) -def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_X_8GF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_X_16GF_Weights.IMAGENET1K_V1)) -def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_X_16GF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_X_32GF_Weights.IMAGENET1K_V1)) -def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_X_32GF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs) - return _regnet(params, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/resnet.py b/torchvision/prototype/models/resnet.py deleted file mode 100644 index 35e30c0e760..00000000000 --- a/torchvision/prototype/models/resnet.py +++ /dev/null @@ -1,381 +0,0 @@ -from functools import partial -from typing import Any, List, Optional, Type, Union - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.resnet import BasicBlock, Bottleneck, ResNet -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "ResNet", - "ResNet18_Weights", - "ResNet34_Weights", - "ResNet50_Weights", - "ResNet101_Weights", - "ResNet152_Weights", - "ResNeXt50_32X4D_Weights", - "ResNeXt101_32X8D_Weights", - "Wide_ResNet50_2_Weights", - "Wide_ResNet101_2_Weights", - "resnet18", - "resnet34", - "resnet50", - "resnet101", - "resnet152", - "resnext50_32x4d", - "resnext101_32x8d", - "wide_resnet50_2", - "wide_resnet101_2", -] - - -def _resnet( - block: Type[Union[BasicBlock, Bottleneck]], - layers: List[int], - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: Any, -) -> ResNet: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = ResNet(block, layers, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "image_classification", - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - 
"interpolation": InterpolationMode.BILINEAR, -} - - -class ResNet18_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnet18-f37072fd.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 11689512, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", - "acc@1": 69.758, - "acc@5": 89.078, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ResNet34_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnet34-b627a593.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 21797672, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", - "acc@1": 73.314, - "acc@5": 91.420, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ResNet50_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnet50-0676ba61.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 25557032, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", - "acc@1": 76.130, - "acc@5": 92.862, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 25557032, - "recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621", - "acc@1": 80.858, - "acc@5": 95.434, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class ResNet101_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnet101-63fe2227.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 44549160, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", - "acc@1": 77.374, - "acc@5": 93.546, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/resnet101-cd907fc2.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 44549160, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 81.886, - "acc@5": 95.780, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class ResNet152_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnet152-394f9c45.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 60192808, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", - "acc@1": 78.312, - "acc@5": 94.046, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/resnet152-f82ba261.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": 
"ResNet", - "publication_year": 2015, - "num_params": 60192808, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 82.284, - "acc@5": 96.002, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class ResNeXt50_32X4D_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNeXt", - "publication_year": 2016, - "num_params": 25028904, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext", - "acc@1": 77.618, - "acc@5": 93.698, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "ResNeXt", - "publication_year": 2016, - "num_params": 25028904, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 81.198, - "acc@5": 95.340, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class ResNeXt101_32X8D_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNeXt", - "publication_year": 2016, - "num_params": 88791336, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext", - "acc@1": 79.312, - "acc@5": 94.526, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "ResNeXt", - "publication_year": 2016, - "num_params": 88791336, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", - "acc@1": 82.834, - "acc@5": 96.228, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class Wide_ResNet50_2_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "WideResNet", - "publication_year": 2016, - "num_params": 68883240, - "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439", - "acc@1": 78.468, - "acc@5": 94.086, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "WideResNet", - "publication_year": 2016, - "num_params": 68883240, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", - "acc@1": 81.602, - "acc@5": 95.758, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class Wide_ResNet101_2_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "WideResNet", - "publication_year": 2016, - "num_params": 126886696, - "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439", - "acc@1": 78.848, - "acc@5": 94.284, - }, - ) - IMAGENET1K_V2 = Weights( - 
url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "WideResNet", - "publication_year": 2016, - "num_params": 126886696, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 82.510, - "acc@5": 96.020, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1)) -def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - weights = ResNet18_Weights.verify(weights) - - return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1)) -def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - weights = ResNet34_Weights.verify(weights) - - return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1)) -def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - weights = ResNet50_Weights.verify(weights) - - return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1)) -def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - weights = ResNet101_Weights.verify(weights) - - return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1)) -def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - weights = ResNet152_Weights.verify(weights) - - return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.IMAGENET1K_V1)) -def resnext50_32x4d( - *, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any -) -> ResNet: - weights = ResNeXt50_32X4D_Weights.verify(weights) - - _ovewrite_named_param(kwargs, "groups", 32) - _ovewrite_named_param(kwargs, "width_per_group", 4) - return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.IMAGENET1K_V1)) -def resnext101_32x8d( - *, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any -) -> ResNet: - weights = ResNeXt101_32X8D_Weights.verify(weights) - - _ovewrite_named_param(kwargs, "groups", 32) - _ovewrite_named_param(kwargs, "width_per_group", 8) - return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1)) -def wide_resnet50_2( - *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any -) -> ResNet: - weights = Wide_ResNet50_2_Weights.verify(weights) - - _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) - return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.IMAGENET1K_V1)) -def wide_resnet101_2( - *, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any -) -> ResNet: - 
weights = Wide_ResNet101_2_Weights.verify(weights) - - _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) - return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py index 7165078161f..2c8d7f6ad84 100644 --- a/torchvision/prototype/models/segmentation/deeplabv3.py +++ b/torchvision/prototype/models/segmentation/deeplabv3.py @@ -1,16 +1,13 @@ from functools import partial from typing import Any, Optional -from torchvision.prototype.transforms import SemanticSegmentationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet -from .._api import WeightsEnum, Weights -from .._meta import _VOC_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large -from ..resnet import resnet50, resnet101 -from ..resnet import ResNet50_Weights, ResNet101_Weights +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _VOC_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param +from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large +from torchvision.models.resnet import resnet50, resnet101, ResNet50_Weights, ResNet101_Weights +from torchvision.models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet +from torchvision.transforms import SemanticSegmentationEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/segmentation/fcn.py b/torchvision/prototype/models/segmentation/fcn.py index 1dfc251844f..e7b12621940 100644 --- a/torchvision/prototype/models/segmentation/fcn.py +++ b/torchvision/prototype/models/segmentation/fcn.py @@ -1,14 +1,12 @@ from functools import partial from typing import Any, Optional -from torchvision.prototype.transforms import SemanticSegmentationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.segmentation.fcn import FCN, _fcn_resnet -from .._api import WeightsEnum, Weights -from .._meta import _VOC_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, ResNet101_Weights, resnet50, resnet101 +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _VOC_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param +from torchvision.models.resnet import ResNet50_Weights, ResNet101_Weights, resnet50, resnet101 +from torchvision.models.segmentation.fcn import FCN, _fcn_resnet +from torchvision.transforms import SemanticSegmentationEval, InterpolationMode __all__ = ["FCN", "FCN_ResNet50_Weights", "FCN_ResNet101_Weights", "fcn_resnet50", "fcn_resnet101"] diff --git a/torchvision/prototype/models/segmentation/lraspp.py b/torchvision/prototype/models/segmentation/lraspp.py index 2c0fa6f0aff..21c15373089 100644 --- a/torchvision/prototype/models/segmentation/lraspp.py +++ b/torchvision/prototype/models/segmentation/lraspp.py @@ -1,14 +1,12 @@ from functools import partial from typing import Any, Optional -from torchvision.prototype.transforms import SemanticSegmentationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.segmentation.lraspp import LRASPP, 
_lraspp_mobilenetv3 -from .._api import WeightsEnum, Weights -from .._meta import _VOC_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _VOC_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param +from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large +from torchvision.models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3 +from torchvision.transforms import SemanticSegmentationEval, InterpolationMode __all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"] diff --git a/torchvision/prototype/models/shufflenetv2.py b/torchvision/prototype/models/shufflenetv2.py deleted file mode 100644 index 48047a70c60..00000000000 --- a/torchvision/prototype/models/shufflenetv2.py +++ /dev/null @@ -1,124 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.shufflenetv2 import ShuffleNetV2 -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "ShuffleNetV2", - "ShuffleNet_V2_X0_5_Weights", - "ShuffleNet_V2_X1_0_Weights", - "ShuffleNet_V2_X1_5_Weights", - "ShuffleNet_V2_X2_0_Weights", - "shufflenet_v2_x0_5", - "shufflenet_v2_x1_0", - "shufflenet_v2_x1_5", - "shufflenet_v2_x2_0", -] - - -def _shufflenetv2( - weights: Optional[WeightsEnum], - progress: bool, - *args: Any, - **kwargs: Any, -) -> ShuffleNetV2: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = ShuffleNetV2(*args, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "image_classification", - "architecture": "ShuffleNetV2", - "publication_year": 2018, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/barrh/Shufflenet-v2-Pytorch/tree/v0.1.0", -} - - -class ShuffleNet_V2_X0_5_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 1366792, - "acc@1": 69.362, - "acc@5": 88.316, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ShuffleNet_V2_X1_0_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 2278604, - "acc@1": 60.552, - "acc@5": 81.746, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ShuffleNet_V2_X1_5_Weights(WeightsEnum): - pass - - -class ShuffleNet_V2_X2_0_Weights(WeightsEnum): - pass - - -@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1)) -def shufflenet_v2_x0_5( - *, weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any -) -> ShuffleNetV2: - weights = ShuffleNet_V2_X0_5_Weights.verify(weights) - - return _shufflenetv2(weights, progress, 
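In the ShuffleNetV2 hunks above, only the x0.5 and x1.0 variants ship ImageNet checkpoints; the x1.5 and x2.0 enums are left empty. A small sketch of what that means at the call site, using the builder names as defined in this file:

```python
from torchvision.models import ShuffleNet_V2_X1_0_Weights, shufflenet_v2_x1_0, shufflenet_v2_x1_5

pretrained = shufflenet_v2_x1_0(weights=ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1)
scratch = shufflenet_v2_x1_5(weights=None)   # no pretrained entry exists for x1.5 here
```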
[4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1)) -def shufflenet_v2_x1_0( - *, weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any -) -> ShuffleNetV2: - weights = ShuffleNet_V2_X1_0_Weights.verify(weights) - - return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) - - -@handle_legacy_interface(weights=("pretrained", None)) -def shufflenet_v2_x1_5( - *, weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any -) -> ShuffleNetV2: - weights = ShuffleNet_V2_X1_5_Weights.verify(weights) - - return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) - - -@handle_legacy_interface(weights=("pretrained", None)) -def shufflenet_v2_x2_0( - *, weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any -) -> ShuffleNetV2: - weights = ShuffleNet_V2_X2_0_Weights.verify(weights) - - return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) diff --git a/torchvision/prototype/models/squeezenet.py b/torchvision/prototype/models/squeezenet.py deleted file mode 100644 index 7f6a034ed6c..00000000000 --- a/torchvision/prototype/models/squeezenet.py +++ /dev/null @@ -1,88 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.squeezenet import SqueezeNet -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"] - - -_COMMON_META = { - "task": "image_classification", - "architecture": "SqueezeNet", - "publication_year": 2016, - "size": (224, 224), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717", -} - - -class SqueezeNet1_0_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "min_size": (21, 21), - "num_params": 1248424, - "acc@1": 58.092, - "acc@5": 80.420, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class SqueezeNet1_1_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "min_size": (17, 17), - "num_params": 1235496, - "acc@1": 58.178, - "acc@5": 80.624, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", SqueezeNet1_0_Weights.IMAGENET1K_V1)) -def squeezenet1_0( - *, weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any -) -> SqueezeNet: - weights = SqueezeNet1_0_Weights.verify(weights) - - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = SqueezeNet("1_0", **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -@handle_legacy_interface(weights=("pretrained", SqueezeNet1_1_Weights.IMAGENET1K_V1)) -def squeezenet1_1( 
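Unlike the other classification files, the SqueezeNet builders above inline the num_classes override and state-dict loading rather than going through a shared helper, and the two variants advertise different minimum input sizes in their meta. A quick sketch with the enum and builder names as above:

```python
from torchvision.models import SqueezeNet1_1_Weights, squeezenet1_1

w = SqueezeNet1_1_Weights.IMAGENET1K_V1
model = squeezenet1_1(weights=w)
print(w.meta["min_size"])    # (17, 17) for the 1.1 weights, versus (21, 21) for 1.0
```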
- *, weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any -) -> SqueezeNet: - weights = SqueezeNet1_1_Weights.verify(weights) - - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = SqueezeNet("1_1", **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model diff --git a/torchvision/prototype/models/vgg.py b/torchvision/prototype/models/vgg.py deleted file mode 100644 index 233c35418ed..00000000000 --- a/torchvision/prototype/models/vgg.py +++ /dev/null @@ -1,240 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.vgg import VGG, make_layers, cfgs -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "VGG", - "VGG11_Weights", - "VGG11_BN_Weights", - "VGG13_Weights", - "VGG13_BN_Weights", - "VGG16_Weights", - "VGG16_BN_Weights", - "VGG19_Weights", - "VGG19_BN_Weights", - "vgg11", - "vgg11_bn", - "vgg13", - "vgg13_bn", - "vgg16", - "vgg16_bn", - "vgg19", - "vgg19_bn", -] - - -def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - return model - - -_COMMON_META = { - "task": "image_classification", - "architecture": "VGG", - "publication_year": 2014, - "size": (224, 224), - "min_size": (32, 32), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg", -} - - -class VGG11_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg11-8a719046.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 132863336, - "acc@1": 69.020, - "acc@5": 88.628, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG11_BN_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 132868840, - "acc@1": 70.370, - "acc@5": 89.810, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG13_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg13-19584684.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 133047848, - "acc@1": 69.928, - "acc@5": 89.246, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG13_BN_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 133053736, - "acc@1": 71.586, - "acc@5": 90.374, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG16_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg16-397923af.pth", - 
transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 138357544, - "acc@1": 71.592, - "acc@5": 90.382, - }, - ) - # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the - # same input standardization method as the paper. Only the `features` weights have proper values, those on the - # `classifier` module are filled with nans. - IMAGENET1K_FEATURES = Weights( - url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth", - transforms=partial( - ImageClassificationEval, - crop_size=224, - mean=(0.48235, 0.45882, 0.40784), - std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0), - ), - meta={ - **_COMMON_META, - "num_params": 138357544, - "categories": None, - "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd", - "acc@1": float("nan"), - "acc@5": float("nan"), - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG16_BN_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 138365992, - "acc@1": 73.360, - "acc@5": 91.516, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG19_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 143667240, - "acc@1": 72.376, - "acc@5": 90.876, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG19_BN_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 143678248, - "acc@1": 74.218, - "acc@5": 91.842, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1)) -def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG11_Weights.verify(weights) - - return _vgg("A", False, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1)) -def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG11_BN_Weights.verify(weights) - - return _vgg("A", True, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1)) -def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG13_Weights.verify(weights) - - return _vgg("B", False, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1)) -def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG13_BN_Weights.verify(weights) - - return _vgg("B", True, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1)) -def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG16_Weights.verify(weights) - - return _vgg("D", False, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1)) -def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: 
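One subtlety in the VGG16 enum above: DEFAULT stays on the IMAGENET1K_V1 classification checkpoint, because IMAGENET1K_FEATURES is a backbone-only export (NaN classifier, 0-255 input scaling from the amdegroot recipe). A short sketch of how the two entries differ, assuming the enum is exposed from torchvision.models:

```python
from torchvision.models import VGG16_Weights, vgg16

assert VGG16_Weights.DEFAULT is VGG16_Weights.IMAGENET1K_V1   # FEATURES is not the default
classifier = vgg16(weights=VGG16_Weights.DEFAULT)

feats_preset = VGG16_Weights.IMAGENET1K_FEATURES.transforms()  # custom mean, std = 1/255
```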
- weights = VGG16_BN_Weights.verify(weights) - - return _vgg("D", True, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1)) -def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG19_Weights.verify(weights) - - return _vgg("E", False, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1)) -def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG19_BN_Weights.verify(weights) - - return _vgg("E", True, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/video/resnet.py b/torchvision/prototype/models/video/resnet.py index 790d254d266..0f4c0dd1dc9 100644 --- a/torchvision/prototype/models/video/resnet.py +++ b/torchvision/prototype/models/video/resnet.py @@ -2,10 +2,10 @@ from typing import Any, Callable, List, Optional, Sequence, Type, Union from torch import nn -from torchvision.prototype.transforms import VideoClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.video.resnet import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _KINETICS400_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param +from torchvision.models.video.resnet import ( BasicBlock, BasicStem, Bottleneck, @@ -15,9 +15,7 @@ R2Plus1dStem, VideoResNet, ) -from .._api import WeightsEnum, Weights -from .._meta import _KINETICS400_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param +from torchvision.transforms import VideoClassificationEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/vision_transformer.py b/torchvision/prototype/models/vision_transformer.py deleted file mode 100644 index 468903b6b94..00000000000 --- a/torchvision/prototype/models/vision_transformer.py +++ /dev/null @@ -1,198 +0,0 @@ -# References: -# https://github.com/google-research/vision_transformer -# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/vision_transformer.py - -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.vision_transformer import VisionTransformer, interpolate_embeddings # noqa: F401 -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - -__all__ = [ - "VisionTransformer", - "ViT_B_16_Weights", - "ViT_B_32_Weights", - "ViT_L_16_Weights", - "ViT_L_32_Weights", - "vit_b_16", - "vit_b_32", - "vit_l_16", - "vit_l_32", -] - - -_COMMON_META = { - "task": "image_classification", - "architecture": "ViT", - "publication_year": 2020, - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, -} - - -class ViT_B_16_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vit_b_16-c867db91.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 86567656, - "size": (224, 224), - "min_size": (224, 224), - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_16", - "acc@1": 81.072, - "acc@5": 95.318, - 
}, - ) - DEFAULT = IMAGENET1K_V1 - - -class ViT_B_32_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 88224232, - "size": (224, 224), - "min_size": (224, 224), - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_32", - "acc@1": 75.912, - "acc@5": 92.466, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ViT_L_16_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=242), - meta={ - **_COMMON_META, - "num_params": 304326632, - "size": (224, 224), - "min_size": (224, 224), - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_16", - "acc@1": 79.662, - "acc@5": 94.638, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ViT_L_32_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vit_l_32-c7638314.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 306535400, - "size": (224, 224), - "min_size": (224, 224), - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_32", - "acc@1": 76.972, - "acc@5": 93.07, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -def _vision_transformer( - patch_size: int, - num_layers: int, - num_heads: int, - hidden_dim: int, - mlp_dim: int, - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: Any, -) -> VisionTransformer: - image_size = kwargs.pop("image_size", 224) - - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = VisionTransformer( - image_size=image_size, - patch_size=patch_size, - num_layers=num_layers, - num_heads=num_heads, - hidden_dim=hidden_dim, - mlp_dim=mlp_dim, - **kwargs, - ) - - if weights: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1)) -def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: - weights = ViT_B_16_Weights.verify(weights) - - return _vision_transformer( - patch_size=16, - num_layers=12, - num_heads=12, - hidden_dim=768, - mlp_dim=3072, - weights=weights, - progress=progress, - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.IMAGENET1K_V1)) -def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: - weights = ViT_B_32_Weights.verify(weights) - - return _vision_transformer( - patch_size=32, - num_layers=12, - num_heads=12, - hidden_dim=768, - mlp_dim=3072, - weights=weights, - progress=progress, - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", ViT_L_16_Weights.IMAGENET1K_V1)) -def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: - weights = ViT_L_16_Weights.verify(weights) - - return _vision_transformer( - patch_size=16, - num_layers=24, - num_heads=16, - hidden_dim=1024, - mlp_dim=4096, - weights=weights, - progress=progress, - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", ViT_L_32_Weights.IMAGENET1K_V1)) -def vit_l_32(*, weights: 
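The ViT builders in this deleted file are thin wrappers over _vision_transformer; vit_b_16, for example, fixes patch_size=16, 12 layers, 12 heads, hidden_dim=768 and mlp_dim=3072, with image_size defaulting to 224. A hypothetical direct construction showing roughly what vit_b_16(weights=None) builds:

```python
from torchvision.models.vision_transformer import VisionTransformer

model = VisionTransformer(
    image_size=224,
    patch_size=16,
    num_layers=12,
    num_heads=12,
    hidden_dim=768,
    mlp_dim=3072,
)
```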
Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: - weights = ViT_L_32_Weights.verify(weights) - - return _vision_transformer( - patch_size=32, - num_layers=24, - num_heads=16, - hidden_dim=1024, - mlp_dim=4096, - weights=weights, - progress=progress, - **kwargs, - ) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 2b52a253820..c900fca5754 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -1,5 +1,3 @@ -from torchvision.transforms import InterpolationMode, AutoAugmentPolicy # usort: skip - from . import functional # usort: skip from ._transform import Transform # usort: skip @@ -19,11 +17,4 @@ ) from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace from ._misc import Identity, Normalize, ToDtype, Lambda -from ._presets import ( - ObjectDetectionEval, - ImageClassificationEval, - SemanticSegmentationEval, - VideoClassificationEval, - OpticalFlowEval, -) from ._type_conversion import DecodeImage, LabelToOneHot diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index c451feb9a32..7fc62423ab8 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -4,9 +4,10 @@ import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F +from torchvision.prototype.transforms import Transform, functional as F from torchvision.prototype.utils._internal import query_recursively -from torchvision.transforms.functional import pil_to_tensor, to_pil_image +from torchvision.transforms.autoaugment import AutoAugmentPolicy +from torchvision.transforms.functional import pil_to_tensor, to_pil_image, InterpolationMode from ._utils import get_image_dimensions, is_simple_tensor diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 2a965959629..e568c92c866 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -6,8 +6,8 @@ import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F -from torchvision.transforms.functional import pil_to_tensor +from torchvision.prototype.transforms import Transform, functional as F +from torchvision.transforms.functional import pil_to_tensor, InterpolationMode from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 1bff7a3f2e6..3d6d43f955a 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -4,9 +4,8 @@ import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import InterpolationMode from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP -from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix +from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix, 
InterpolationMode from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index 147a7f0ff4c..fe5284394cb 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -1,14 +1,11 @@ import collections.abc import difflib -import functools -import inspect import io import mmap import os import os.path import platform import textwrap -import warnings from typing import ( Any, BinaryIO, @@ -36,7 +33,6 @@ "FrozenMapping", "make_repr", "FrozenBunch", - "kwonly_to_pos_or_kw", "fromfile", "ReadOnlyTensorBuffer", "apply_recursively", @@ -140,57 +136,6 @@ def __repr__(self) -> str: return make_repr(type(self).__name__, self.items()) -def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]: - """Decorates a function that uses keyword only parameters to also allow them being passed as positionals. - - For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``: - - .. code:: - - def old_fn(foo, bar, baz=None): - ... - - def new_fn(foo, *, bar, baz=None): - ... - - Calling ``old_fn("foo", "bar, "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep BC - and at the same time warn the user of the deprecation, this decorator can be used: - - .. code:: - - @kwonly_to_pos_or_kw - def new_fn(foo, *, bar, baz=None): - ... - - new_fn("foo", "bar, "baz") - """ - params = inspect.signature(fn).parameters - - try: - keyword_only_start_idx = next( - idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY - ) - except StopIteration: - raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None - - keyword_only_params = tuple(inspect.signature(fn).parameters)[keyword_only_start_idx:] - - @functools.wraps(fn) - def wrapper(*args: Any, **kwargs: Any) -> D: - args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:] - if keyword_only_args: - keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args)) - warnings.warn( - f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional " - f"parameter(s) is deprecated. Please use keyword parameter(s) instead." 
- ) - kwargs.update(keyword_only_kwargs) - - return fn(*args, **kwargs) - - return wrapper - - def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray: # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable return bytearray(file.read(-1 if count == -1 else count * item_size)) diff --git a/torchvision/transforms/__init__.py b/torchvision/transforms/__init__.py index 77680a14f0d..94ec34ebe98 100644 --- a/torchvision/transforms/__init__.py +++ b/torchvision/transforms/__init__.py @@ -1,2 +1,9 @@ from .transforms import * from .autoaugment import * +from ._presets import ( + ObjectDetectionEval, + ImageClassificationEval, + SemanticSegmentationEval, + VideoClassificationEval, + OpticalFlowEval, +) diff --git a/torchvision/prototype/transforms/_presets.py b/torchvision/transforms/_presets.py similarity index 98% rename from torchvision/prototype/transforms/_presets.py rename to torchvision/transforms/_presets.py index 3ab045b3ddb..a6b85d05597 100644 --- a/torchvision/prototype/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -3,7 +3,7 @@ import torch from torch import Tensor, nn -from ...transforms import functional as F, InterpolationMode +from . import functional as F, InterpolationMode __all__ = [
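With _presets.py relocated out of the prototype area, the evaluation presets listed at the end of this diff become importable from the stable torchvision.transforms namespace. A minimal sketch, using constructor arguments matching the weight entries earlier in the diff:

```python
import torch
from torchvision.transforms import ImageClassificationEval

preprocess = ImageClassificationEval(crop_size=224, resize_size=232)
batch = preprocess(torch.rand(3, 256, 256)).unsqueeze(0)   # -> 1 x 3 x 224 x 224
```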