diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py
index 7bca0276c34..eb949fb3d5c 100644
--- a/torchvision/models/__init__.py
+++ b/torchvision/models/__init__.py
@@ -14,4 +14,4 @@
 from .vision_transformer import *
 from .swin_transformer import *
 from . import detection, optical_flow, quantization, segmentation, video
-from ._api import get_weight
+from ._api import get_model, get_model_weights, get_weight, list_models
diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py
index 901bb0015e4..3f4e9e914af 100644
--- a/torchvision/models/_api.py
+++ b/torchvision/models/_api.py
@@ -3,14 +3,17 @@
 import sys
 from dataclasses import dataclass, fields
 from inspect import signature
-from typing import Any, Callable, cast, Dict, Mapping
+from types import ModuleType
+from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Type, TypeVar, Union
+
+from torch import nn
 
 from torchvision._utils import StrEnum
 
 from .._internally_replaced_utils import load_state_dict_from_url
 
 
-__all__ = ["WeightsEnum", "Weights", "get_weight"]
+__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_weights", "get_weight", "list_models"]
 
 
 @dataclass
@@ -75,7 +78,7 @@ def __getattr__(self, name):
 
 def get_weight(name: str) -> WeightsEnum:
     """
-    Gets the weight enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1"
+    Gets the weights enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1"
 
     Args:
         name (str): The name of the weight enum entry.
@@ -107,10 +110,27 @@ def get_weight(name: str) -> WeightsEnum:
     return weights_enum.from_str(value_name)
 
 
+W = TypeVar("W", bound=Type[WeightsEnum])
+
+
+def get_model_weights(model: Union[Callable, str]) -> W:
+    """
+    Returns the weights enum class associated with the given model.
+
+    Args:
+        model (callable or str): The model builder function or the name under which it is registered.
+
+    Returns:
+        weights_enum (W): The weights enum class associated with the model.
+    """
+    if isinstance(model, str):
+        model = find_model(model)
+    return cast(W, _get_enum_from_fn(model))
+
+
 def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
     """
     Internal method that gets the weight enum of a specific model builder method.
-    Might be removed after the handle_legacy_interface is removed.
 
     Args:
         fn (Callable): The builder method used to create the model.
@@ -140,3 +160,59 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
     )
 
     return cast(WeightsEnum, weights_enum)
+
+
+M = TypeVar("M", bound=Type[nn.Module])
+
+BUILTIN_MODELS = {}
+
+
+def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]:
+    def wrapper(fn: Callable[..., M]) -> Callable[..., M]:
+        key = name if name is not None else fn.__name__
+        if key in BUILTIN_MODELS:
+            raise ValueError(f"An entry is already registered under the name '{key}'.")
+        BUILTIN_MODELS[key] = fn
+        return fn
+
+    return wrapper
+
+
+def list_models(module: Optional[ModuleType] = None) -> List[str]:
+    """
+    Returns a list with the names of registered models.
+
+    Args:
+        module (ModuleType, optional): The module from which we want to extract the available models.
+
+    Returns:
+        models (list): A list with the names of available models.
+ """ + models = [ + k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__ + ] + return sorted(models) + + +def find_model(name: str) -> Callable[..., M]: + name = name.lower() + try: + fn = BUILTIN_MODELS[name] + except KeyError: + raise ValueError(f"Unknown model {name}") + return fn + + +def get_model(name: str, **config: Any) -> M: + """ + Gets the model name and configuration and returns an instantiated model. + + Args: + name (str): The name under which the model is registered. + **config (Any): parameters passed to the model builder method. + + Returns: + model (nn.Module): The initialized model. + """ + fn = find_model(name) + return fn(**config) diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index 5d1401dcb36..328f978ba11 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -6,7 +6,7 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param, handle_legacy_interface @@ -75,6 +75,7 @@ class AlexNet_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 +@register_model() @handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1)) def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: """AlexNet model architecture from `One weird trick for parallelizing convolutional neural networks `__. diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 10d2a1c91ac..81fc3c5d4c0 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -8,7 +8,7 @@ from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface @@ -371,6 +371,7 @@ class MobileNet_V3_Small_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 +@register_model() @handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.IMAGENET1K_V1)) def mobilenet_v3_large( *, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any @@ -401,6 +402,7 @@ def mobilenet_v3_large( return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) +@register_model() @handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1)) def mobilenet_v3_small( *, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 56341bb280e..986f67c6080 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -7,7 +7,7 @@ from ...ops.misc import Conv2dNormActivation, SqueezeExcitation from ...transforms._presets import ImageClassification -from .._api import Weights, WeightsEnum +from .._api import register_model, Weights, WeightsEnum from .._meta import _IMAGENET_CATEGORIES from .._utils import _ovewrite_named_param, handle_legacy_interface from ..mobilenetv3 import ( @@ -184,6 +184,7 @@ class 
     DEFAULT = IMAGENET1K_QNNPACK_V1
 
 
+@register_model(name="quantized_mobilenet_v3_large")
 @handle_legacy_interface(
     weights=(
         "pretrained",
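
For reviewers, a minimal usage sketch of the registration surface added above, assuming a torchvision build with this diff applied. The `my_tiny_net` builder is hypothetical and only illustrates `register_model`; since this diff does not re-export `register_model` from `torchvision.models`, the sketch imports it from the private `_api` module.

```python
import torch
from torch import nn

from torchvision.models import get_model, get_model_weights, get_weight, list_models
from torchvision.models._api import register_model  # not re-exported by this diff


# Register a custom builder under an explicit name (hypothetical example).
@register_model(name="my_tiny_net")
def my_tiny_net(num_classes: int = 10, **kwargs) -> nn.Module:
    return nn.Sequential(nn.Flatten(), nn.LazyLinear(num_classes))


print(list_models())  # sorted names, e.g. ["alexnet", "mobilenet_v3_large", ..., "my_tiny_net"]

# Look up weights either from the model name or by the full enum entry name.
weights_enum = get_model_weights("mobilenet_v3_large")  # MobileNet_V3_Large_Weights
weights = get_weight("MobileNet_V3_Large_Weights.IMAGENET1K_V1")
assert isinstance(weights, weights_enum)

# Instantiate models by their registered names; extra kwargs reach the builder.
model = get_model("mobilenet_v3_large", weights=weights)
tiny = get_model("my_tiny_net", num_classes=5)

with torch.inference_mode():
    out = model.eval()(torch.rand(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
```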