Skip to content

[RFC] Registration mechanism for models #6330

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 14 commits into from
2 changes: 1 addition & 1 deletion torchvision/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
from .vision_transformer import *
from .swin_transformer import *
from . import detection, optical_flow, quantization, segmentation, video
from ._api import get_weight
from ._api import get_model, get_model_weights, get_weight, list_models
84 changes: 80 additions & 4 deletions torchvision/models/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
import sys
from dataclasses import dataclass, fields
from inspect import signature
from typing import Any, Callable, cast, Dict, Mapping
from types import ModuleType
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Type, TypeVar, Union

from torch import nn

from torchvision._utils import StrEnum

from .._internally_replaced_utils import load_state_dict_from_url


__all__ = ["WeightsEnum", "Weights", "get_weight"]
__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_weights", "get_weight", "list_models"]


@dataclass
Expand Down Expand Up @@ -75,7 +78,7 @@ def __getattr__(self, name):

def get_weight(name: str) -> WeightsEnum:
"""
Gets the weight enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1"
Gets the weights enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1"

Args:
name (str): The name of the weight enum entry.
Expand Down Expand Up @@ -107,10 +110,27 @@ def get_weight(name: str) -> WeightsEnum:
return weights_enum.from_str(value_name)


# Type variable standing for "a WeightsEnum subclass" (the enum class itself, not one of its members).
W = TypeVar("W", bound=Type[WeightsEnum])


def get_model_weights(model: Union[Callable, str]) -> W:
    """
    Returns the weights enum class associated to the given model.

    Args:
        model (callable or str): The model builder function or the name under which it is registered.

    Returns:
        weights_enum (W): The weights enum class associated with the model.

    Raises:
        ValueError: If ``model`` is a string that does not match any registered model name.
    """
    # A string is treated as a registry key and resolved to its builder function first;
    # anything else is assumed to already be the builder.
    builder = find_model(model) if isinstance(model, str) else model
    # ``_get_enum_from_fn`` is annotated as returning ``WeightsEnum``; the cast only adjusts the static type.
    return cast(W, _get_enum_from_fn(builder))


def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
"""
Internal method that gets the weight enum of a specific model builder method.
Might be removed after the handle_legacy_interface is removed.

Args:
fn (Callable): The builder method used to create the model.
Expand Down Expand Up @@ -140,3 +160,59 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
)

return cast(WeightsEnum, weights_enum)


# Bound to the module *instance* type: model builders return constructed models, not model classes.
M = TypeVar("M", bound=nn.Module)

# Registry mapping a model name to its builder function; populated by ``register_model``.
BUILTIN_MODELS: Dict[str, Callable[..., nn.Module]] = {}


def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]:
    """
    Decorator factory that records a model builder in ``BUILTIN_MODELS``.

    Args:
        name (str, optional): The key under which the builder is registered.
            When omitted, the builder's ``__name__`` is used.

    Returns:
        decorator (callable): A decorator that registers the builder and returns it unchanged.

    Raises:
        ValueError: Raised by the returned decorator if the key is already registered.
    """

    def wrapper(fn: Callable[..., M]) -> Callable[..., M]:
        key = name if name is not None else fn.__name__
        # Refuse to silently overwrite an existing entry: two builders sharing a name is a bug.
        if key in BUILTIN_MODELS:
            raise ValueError(f"An entry is already registered under the name '{key}'.")
        # NOTE(review): ``find_model`` lower-cases the requested name before lookup while keys are
        # stored verbatim here, so a key containing uppercase characters could not be looked up.
        BUILTIN_MODELS[key] = fn
        return fn

    return wrapper


def list_models(module: Optional[ModuleType] = None) -> List[str]:
    """
    Lists the names under which models are currently registered, in sorted order.

    Args:
        module (ModuleType, optional): When given, only models whose builder is defined in a
            direct submodule of this module are returned. When ``None``, every registered
            model is returned.

    Returns:
        models (list): The sorted names of the matching registered models.
    """
    if module is None:
        names = list(BUILTIN_MODELS)
    else:
        # Strip the last component of the builder's module path to get its parent package
        # and compare that against the requested module's name.
        names = [
            key for key, builder in BUILTIN_MODELS.items() if builder.__module__.rsplit(".", 1)[0] == module.__name__
        ]
    names.sort()
    return names


def find_model(name: str) -> Callable[..., M]:
    """
    Looks up the builder function registered under the given model name.

    Args:
        name (str): The name under which the model is registered. It is lower-cased before lookup.

    Returns:
        fn (callable): The builder function stored in ``BUILTIN_MODELS`` for that name.

    Raises:
        ValueError: If no model is registered under the (lower-cased) name.
    """
    name = name.lower()
    try:
        fn = BUILTIN_MODELS[name]
    except KeyError:
        # The KeyError is an implementation detail of the registry lookup; suppress it so the
        # caller's traceback only shows the ValueError.
        raise ValueError(f"Unknown model {name}") from None
    return fn


def get_model(name: str, **config: Any) -> M:
    """
    Builds a model from its registered name.

    Args:
        name (str): The name under which the model is registered.
        **config (Any): Keyword arguments forwarded unchanged to the model builder.

    Returns:
        model (nn.Module): Whatever the registered builder returns for the given configuration.
    """
    # Resolve the builder by name, then call it with the caller's keyword arguments.
    return find_model(name)(**config)
3 changes: 2 additions & 1 deletion torchvision/models/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface

Expand Down Expand Up @@ -75,6 +75,7 @@ class AlexNet_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1


@register_model()
@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1))
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
"""AlexNet model architecture from `One weird trick for parallelizing convolutional neural networks <https://arxiv.org/abs/1404.5997>`__.
Expand Down
4 changes: 3 additions & 1 deletion torchvision/models/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface

Expand Down Expand Up @@ -371,6 +371,7 @@ class MobileNet_V3_Small_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1


@register_model()
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.IMAGENET1K_V1))
def mobilenet_v3_large(
*, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any
Expand Down Expand Up @@ -401,6 +402,7 @@ def mobilenet_v3_large(
return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1))
def mobilenet_v3_small(
*, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any
Expand Down
3 changes: 2 additions & 1 deletion torchvision/models/quantization/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ...ops.misc import Conv2dNormActivation, SqueezeExcitation
from ...transforms._presets import ImageClassification
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface
from ..mobilenetv3 import (
Expand Down Expand Up @@ -184,6 +184,7 @@ class MobileNet_V3_Large_QuantizedWeights(WeightsEnum):
DEFAULT = IMAGENET1K_QNNPACK_V1


@register_model(name="quantized_mobilenet_v3_large")
@handle_legacy_interface(
weights=(
"pretrained",
Expand Down