From 0d62aaf8baea3a96901f9f04e0509315c16c875c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 28 Jul 2022 18:40:40 +0100 Subject: [PATCH 01/14] Model registration mechanism. --- torchvision/models/__init__.py | 2 +- torchvision/models/_api.py | 51 +++++++++++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 7bca0276c34..698977552ad 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -14,4 +14,4 @@ from .vision_transformer import * from .swin_transformer import * from . import detection, optical_flow, quantization, segmentation, video -from ._api import get_weight +from ._api import get_weight, list_models, load diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 901bb0015e4..e8f88a8a930 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -3,7 +3,9 @@ import sys from dataclasses import dataclass, fields from inspect import signature -from typing import Any, Callable, cast, Dict, Mapping +from torch import nn +from types import ModuleType +from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Type, TypeVar from torchvision._utils import StrEnum @@ -140,3 +142,50 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum: ) return cast(WeightsEnum, weights_enum) + + +M = TypeVar("M", bound=Type[nn.Module]) + +BUILTIN_MODELS = {} + + +def register_model(name: str, overwrite: bool = False) -> Callable[[Callable[..., M]], Callable[..., M]]: + def wrapper(fn: Callable[..., M]) -> Callable[..., M]: + if name in BUILTIN_MODELS and not overwrite: + raise ValueError(f"A model is already registered under tha name '{name}'.") + BUILTIN_MODELS[name] = fn + return fn + return wrapper + + +def list_models(module: Optional[ModuleType] = None) -> List[str]: + """ + Returns a list with the names of registred models. + + Args: + module (ModuleType, optional): The module from which we want to extract the available models. + + Returns: + models (list): A list with the names of available models. + """ + models = [k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__ == module] + return sorted(models) + + +def load(name: str, **config: Any) -> M: + """ + Gets the model name and configuration and returns an instantiated model. + + Args: + name (str): The name under which the model is registered. + **config (Any): parameters passed to the model builder method. + + Returns: + model (nn.Module): The initialized model. + """ + name = name.lower() + try: + fn = BUILTIN_MODELS[name] + except KeyError: + raise ValueError(f"Unknown model {name}") + return fn(**config) From 0e7eb8a0dc1e013d3ffeeccc4520d5c8090fbbfa Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 28 Jul 2022 18:42:58 +0100 Subject: [PATCH 02/14] Add overwrite options to the dataset prototype registration mechanism. --- torchvision/models/_api.py | 2 +- torchvision/prototype/datasets/_api.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index e8f88a8a930..a5422d08443 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -152,7 +152,7 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum: def register_model(name: str, overwrite: bool = False) -> Callable[[Callable[..., M]], Callable[..., M]]: def wrapper(fn: Callable[..., M]) -> Callable[..., M]: if name in BUILTIN_MODELS and not overwrite: - raise ValueError(f"A model is already registered under tha name '{name}'.") + raise ValueError(f"An entry is already registered under the name '{name}'.") BUILTIN_MODELS[name] = fn return fn return wrapper diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index f6f06c60a21..c3f40c448ca 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -12,8 +12,10 @@ BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {} -def register_info(name: str) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]: +def register_info(name: str, overwrite: bool = False) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]: def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]: + if name in BUILTIN_INFOS and not overwrite: + raise ValueError(f"An entry is already registered under the name '{name}'.") BUILTIN_INFOS[name] = fn() return fn @@ -23,8 +25,10 @@ def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]: BUILTIN_DATASETS = {} -def register_dataset(name: str) -> Callable[[D], D]: +def register_dataset(name: str, overwrite: bool = False) -> Callable[[D], D]: def wrapper(dataset_cls: D) -> D: + if name in BUILTIN_DATASETS and not overwrite: + raise ValueError(f"An entry is already registered under the name '{name}'.") BUILTIN_DATASETS[name] = dataset_cls return dataset_cls From 1520566e7a73001e3610a97291437d53d1520c05 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 28 Jul 2022 18:43:26 +0100 Subject: [PATCH 03/14] Adding example models. --- torchvision/models/alexnet.py | 3 ++- torchvision/models/mobilenetv3.py | 4 +++- torchvision/models/quantization/mobilenetv3.py | 3 ++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index 5d1401dcb36..0b1bd8b1310 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -6,7 +6,7 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum +from ._api import Weights, WeightsEnum, register_model from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param, handle_legacy_interface @@ -75,6 +75,7 @@ class AlexNet_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 +@register_model("alexnet") @handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1)) def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: """AlexNet model architecture from `One weird trick for parallelizing convolutional neural networks `__. diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 10d2a1c91ac..9ac8c99050e 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -8,7 +8,7 @@ from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum +from ._api import Weights, WeightsEnum, register_model from ._meta import _IMAGENET_CATEGORIES from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface @@ -371,6 +371,7 @@ class MobileNet_V3_Small_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 +@register_model("mobilenet_v3_large") @handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.IMAGENET1K_V1)) def mobilenet_v3_large( *, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any @@ -401,6 +402,7 @@ def mobilenet_v3_large( return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) +@register_model("mobilenet_v3_small") @handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1)) def mobilenet_v3_small( *, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 56341bb280e..835fd239453 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -7,7 +7,7 @@ from ...ops.misc import Conv2dNormActivation, SqueezeExcitation from ...transforms._presets import ImageClassification -from .._api import Weights, WeightsEnum +from .._api import Weights, WeightsEnum, register_model from .._meta import _IMAGENET_CATEGORIES from .._utils import _ovewrite_named_param, handle_legacy_interface from ..mobilenetv3 import ( @@ -184,6 +184,7 @@ class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): DEFAULT = IMAGENET1K_QNNPACK_V1 +@register_model("quantized_mobilenet_v3_large") @handle_legacy_interface( weights=( "pretrained", From 2e16077422a89e91067a9c89f5caa053cdb82e88 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 08:31:36 +0100 Subject: [PATCH 04/14] Fix module filtering --- torchvision/models/_api.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index a5422d08443..74940a9c0e7 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -168,7 +168,9 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]: Returns: models (list): A list with the names of available models. """ - models = [k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__ == module] + models = [ + k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__ + ] return sorted(models) From a02c124fb3b845fa387518081be16d365e87e37c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 08:41:25 +0100 Subject: [PATCH 05/14] Fix linter --- torchvision/models/_api.py | 4 +++- torchvision/models/alexnet.py | 2 +- torchvision/models/mobilenetv3.py | 2 +- torchvision/models/quantization/mobilenetv3.py | 2 +- torchvision/prototype/datasets/_api.py | 4 +++- 5 files changed, 9 insertions(+), 5 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 74940a9c0e7..713dc9ece1f 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -3,10 +3,11 @@ import sys from dataclasses import dataclass, fields from inspect import signature -from torch import nn from types import ModuleType from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Type, TypeVar +from torch import nn + from torchvision._utils import StrEnum from .._internally_replaced_utils import load_state_dict_from_url @@ -155,6 +156,7 @@ def wrapper(fn: Callable[..., M]) -> Callable[..., M]: raise ValueError(f"An entry is already registered under the name '{name}'.") BUILTIN_MODELS[name] = fn return fn + return wrapper diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index 0b1bd8b1310..de12bb415e6 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -6,7 +6,7 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum, register_model +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param, handle_legacy_interface diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 9ac8c99050e..e2b571d820a 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -8,7 +8,7 @@ from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import Weights, WeightsEnum, register_model +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 835fd239453..bc916ad6d02 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -7,7 +7,7 @@ from ...ops.misc import Conv2dNormActivation, SqueezeExcitation from ...transforms._presets import ImageClassification -from .._api import Weights, WeightsEnum, register_model +from .._api import register_model, Weights, WeightsEnum from .._meta import _IMAGENET_CATEGORIES from .._utils import _ovewrite_named_param, handle_legacy_interface from ..mobilenetv3 import ( diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index c3f40c448ca..8e237aa3897 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -12,7 +12,9 @@ BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {} -def register_info(name: str, overwrite: bool = False) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]: +def register_info( + name: str, overwrite: bool = False +) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]: def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]: if name in BUILTIN_INFOS and not overwrite: raise ValueError(f"An entry is already registered under the name '{name}'.") From eedf8dfcb4c69adcfbcaa736f9e6e910567784ed Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 08:42:38 +0100 Subject: [PATCH 06/14] Fix docs --- torchvision/models/_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 713dc9ece1f..111a0afde06 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -162,7 +162,7 @@ def wrapper(fn: Callable[..., M]) -> Callable[..., M]: def list_models(module: Optional[ModuleType] = None) -> List[str]: """ - Returns a list with the names of registred models. + Returns a list with the names of registered models. Args: module (ModuleType, optional): The module from which we want to extract the available models. From a91a5b4aabca86a5911ffc94ea03eb554514859e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 09:29:34 +0100 Subject: [PATCH 07/14] Make name optional if same as model builder --- torchvision/models/_api.py | 9 +++++---- torchvision/models/alexnet.py | 2 +- torchvision/models/mobilenetv3.py | 4 ++-- torchvision/models/quantization/mobilenetv3.py | 2 +- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 111a0afde06..5fa6b95d99c 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -150,11 +150,12 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum: BUILTIN_MODELS = {} -def register_model(name: str, overwrite: bool = False) -> Callable[[Callable[..., M]], Callable[..., M]]: +def register_model(name: Optional[str] = None, overwrite: bool = False) -> Callable[[Callable[..., M]], Callable[..., M]]: def wrapper(fn: Callable[..., M]) -> Callable[..., M]: - if name in BUILTIN_MODELS and not overwrite: - raise ValueError(f"An entry is already registered under the name '{name}'.") - BUILTIN_MODELS[name] = fn + key = name if name is not None else fn.__name__ + if key in BUILTIN_MODELS and not overwrite: + raise ValueError(f"An entry is already registered under the name '{key}'.") + BUILTIN_MODELS[key] = fn return fn return wrapper diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index de12bb415e6..328f978ba11 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -75,7 +75,7 @@ class AlexNet_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 -@register_model("alexnet") +@register_model() @handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1)) def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: """AlexNet model architecture from `One weird trick for parallelizing convolutional neural networks `__. diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index e2b571d820a..81fc3c5d4c0 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -371,7 +371,7 @@ class MobileNet_V3_Small_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V1 -@register_model("mobilenet_v3_large") +@register_model() @handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.IMAGENET1K_V1)) def mobilenet_v3_large( *, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any @@ -402,7 +402,7 @@ def mobilenet_v3_large( return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) -@register_model("mobilenet_v3_small") +@register_model() @handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1)) def mobilenet_v3_small( *, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index bc916ad6d02..986f67c6080 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -184,7 +184,7 @@ class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): DEFAULT = IMAGENET1K_QNNPACK_V1 -@register_model("quantized_mobilenet_v3_large") +@register_model(name="quantized_mobilenet_v3_large") @handle_legacy_interface( weights=( "pretrained", From abbe23e83351f05b1f6cf3d982b266e4cbbaf251 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 10:03:22 +0100 Subject: [PATCH 08/14] Apply updates from code-review. --- torchvision/models/_api.py | 6 +++--- torchvision/prototype/datasets/_api.py | 10 ++-------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 5fa6b95d99c..c7d99c87833 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -150,10 +150,10 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum: BUILTIN_MODELS = {} -def register_model(name: Optional[str] = None, overwrite: bool = False) -> Callable[[Callable[..., M]], Callable[..., M]]: +def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]: def wrapper(fn: Callable[..., M]) -> Callable[..., M]: key = name if name is not None else fn.__name__ - if key in BUILTIN_MODELS and not overwrite: + if key in BUILTIN_MODELS: raise ValueError(f"An entry is already registered under the name '{key}'.") BUILTIN_MODELS[key] = fn return fn @@ -177,7 +177,7 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]: return sorted(models) -def load(name: str, **config: Any) -> M: +def load_model(name: str, **config: Any) -> M: """ Gets the model name and configuration and returns an instantiated model. diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index 8e237aa3897..f6f06c60a21 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -12,12 +12,8 @@ BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {} -def register_info( - name: str, overwrite: bool = False -) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]: +def register_info(name: str) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]: def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]: - if name in BUILTIN_INFOS and not overwrite: - raise ValueError(f"An entry is already registered under the name '{name}'.") BUILTIN_INFOS[name] = fn() return fn @@ -27,10 +23,8 @@ def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]: BUILTIN_DATASETS = {} -def register_dataset(name: str, overwrite: bool = False) -> Callable[[D], D]: +def register_dataset(name: str) -> Callable[[D], D]: def wrapper(dataset_cls: D) -> D: - if name in BUILTIN_DATASETS and not overwrite: - raise ValueError(f"An entry is already registered under the name '{name}'.") BUILTIN_DATASETS[name] = dataset_cls return dataset_cls From 1eb8159b13a737c1800f8b0143fb1c9b83b14d8c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 10:05:59 +0100 Subject: [PATCH 09/14] fix minor bug --- torchvision/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 698977552ad..5c28db4660c 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -14,4 +14,4 @@ from .vision_transformer import * from .swin_transformer import * from . import detection, optical_flow, quantization, segmentation, video -from ._api import get_weight, list_models, load +from ._api import get_weight, list_models, load_model From 924388efe8b3916c828c1e98bfa23cd1d8331d5e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 10:26:51 +0100 Subject: [PATCH 10/14] Adding getter for model weight enum --- torchvision/models/__init__.py | 2 +- torchvision/models/_api.py | 41 +++++++++++++++++++++++++--------- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 5c28db4660c..7dee4c389f8 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -14,4 +14,4 @@ from .vision_transformer import * from .swin_transformer import * from . import detection, optical_flow, quantization, segmentation, video -from ._api import get_weight, list_models, load_model +from ._api import get_weight, list_models, get_model, get_model_weight diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index c7d99c87833..f85fe0599b3 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -13,7 +13,7 @@ from .._internally_replaced_utils import load_state_dict_from_url -__all__ = ["WeightsEnum", "Weights", "get_weight"] +__all__ = ["WeightsEnum", "Weights", "get_weight", "list_models", "get_model", "get_model_weight"] @dataclass @@ -110,10 +110,26 @@ def get_weight(name: str) -> WeightsEnum: return weights_enum.from_str(value_name) -def _get_enum_from_fn(fn: Callable) -> WeightsEnum: +W = TypeVar("W", bound=Type[WeightsEnum]) + + +def get_model_weight(name: str) -> W: + """ + Retuns the Weights Enum from the model name. + + Args: + name (str): The name under which the model is registered. + + Returns: + weights_enum (W): The weights enum class associated with the model. + """ + fn = _find_model(name) + return _get_enum_from_fn(fn) + + +def _get_enum_from_fn(fn: Callable) -> W: """ Internal method that gets the weight enum of a specific model builder method. - Might be removed after the handle_legacy_interface is removed. Args: fn (Callable): The builder method used to create the model. @@ -142,7 +158,7 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum: "The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct." ) - return cast(WeightsEnum, weights_enum) + return cast(W, weights_enum) M = TypeVar("M", bound=Type[nn.Module]) @@ -177,7 +193,16 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]: return sorted(models) -def load_model(name: str, **config: Any) -> M: +def _find_model(name: str) -> Callable[..., M]: + name = name.lower() + try: + fn = BUILTIN_MODELS[name] + except KeyError: + raise ValueError(f"Unknown model {name}") + return fn + + +def get_model(name: str, **config: Any) -> M: """ Gets the model name and configuration and returns an instantiated model. @@ -188,9 +213,5 @@ def load_model(name: str, **config: Any) -> M: Returns: model (nn.Module): The initialized model. """ - name = name.lower() - try: - fn = BUILTIN_MODELS[name] - except KeyError: - raise ValueError(f"Unknown model {name}") + fn = _find_model(name) return fn(**config) From bd2327a5666f723753f70d411e319bdb503a2f8e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 10:51:58 +0100 Subject: [PATCH 11/14] Support both strings and callables on get_model_weight. --- torchvision/models/_api.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index f85fe0599b3..b09b5a5e7e4 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, fields from inspect import signature from types import ModuleType -from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Type, TypeVar +from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Type, TypeVar, Union from torch import nn @@ -113,18 +113,19 @@ def get_weight(name: str) -> WeightsEnum: W = TypeVar("W", bound=Type[WeightsEnum]) -def get_model_weight(name: str) -> W: +def get_model_weight(model: Union[Callable, str]) -> W: """ - Retuns the Weights Enum from the model name. + Retuns the Weights Enum of a model. Args: - name (str): The name under which the model is registered. + name (callable or str): The model builder function or the name under which it is registered. Returns: weights_enum (W): The weights enum class associated with the model. """ - fn = _find_model(name) - return _get_enum_from_fn(fn) + if isinstance(model, str): + model = _find_model(model) + return _get_enum_from_fn(model) def _get_enum_from_fn(fn: Callable) -> W: From a815a633394f02fba7c04f2f374d17fd98c8eff3 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 10:57:33 +0100 Subject: [PATCH 12/14] linter fixes --- torchvision/models/__init__.py | 2 +- torchvision/models/_api.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 7dee4c389f8..95f30b44872 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -14,4 +14,4 @@ from .vision_transformer import * from .swin_transformer import * from . import detection, optical_flow, quantization, segmentation, video -from ._api import get_weight, list_models, get_model, get_model_weight +from ._api import get_model, get_model_weight, get_weight, list_models diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index b09b5a5e7e4..66fdb652475 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -13,7 +13,7 @@ from .._internally_replaced_utils import load_state_dict_from_url -__all__ = ["WeightsEnum", "Weights", "get_weight", "list_models", "get_model", "get_model_weight"] +__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_weight", "get_weight", "list_models"] @dataclass @@ -124,7 +124,7 @@ def get_model_weight(model: Union[Callable, str]) -> W: weights_enum (W): The weights enum class associated with the model. """ if isinstance(model, str): - model = _find_model(model) + model = find_model(model) return _get_enum_from_fn(model) @@ -194,7 +194,7 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]: return sorted(models) -def _find_model(name: str) -> Callable[..., M]: +def find_model(name: str) -> Callable[..., M]: name = name.lower() try: fn = BUILTIN_MODELS[name] @@ -214,5 +214,5 @@ def get_model(name: str, **config: Any) -> M: Returns: model (nn.Module): The initialized model. """ - fn = _find_model(name) + fn = find_model(name) return fn(**config) From 9e4e62cf37cec8eaaf59b59d42727bffd9abe4fd Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 11:17:54 +0100 Subject: [PATCH 13/14] Fixing mypy. --- torchvision/models/_api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 66fdb652475..89436bced9a 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -125,10 +125,10 @@ def get_model_weight(model: Union[Callable, str]) -> W: """ if isinstance(model, str): model = find_model(model) - return _get_enum_from_fn(model) + return cast(W, _get_enum_from_fn(model)) -def _get_enum_from_fn(fn: Callable) -> W: +def _get_enum_from_fn(fn: Callable) -> WeightsEnum: """ Internal method that gets the weight enum of a specific model builder method. @@ -159,7 +159,7 @@ def _get_enum_from_fn(fn: Callable) -> W: "The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct." ) - return cast(W, weights_enum) + return cast(WeightsEnum, weights_enum) M = TypeVar("M", bound=Type[nn.Module]) From 020932708a31eb666b487d6a9eb1ff2ac9229170 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jul 2022 14:54:02 +0100 Subject: [PATCH 14/14] Renaming `get_model_weight` to `get_model_weights` --- torchvision/models/__init__.py | 2 +- torchvision/models/_api.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 95f30b44872..eb949fb3d5c 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -14,4 +14,4 @@ from .vision_transformer import * from .swin_transformer import * from . import detection, optical_flow, quantization, segmentation, video -from ._api import get_model, get_model_weight, get_weight, list_models +from ._api import get_model, get_model_weights, get_weight, list_models diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 89436bced9a..3f4e9e914af 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -13,7 +13,7 @@ from .._internally_replaced_utils import load_state_dict_from_url -__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_weight", "get_weight", "list_models"] +__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_weights", "get_weight", "list_models"] @dataclass @@ -78,7 +78,7 @@ def __getattr__(self, name): def get_weight(name: str) -> WeightsEnum: """ - Gets the weight enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1" + Gets the weights enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1" Args: name (str): The name of the weight enum entry. @@ -113,9 +113,9 @@ def get_weight(name: str) -> WeightsEnum: W = TypeVar("W", bound=Type[WeightsEnum]) -def get_model_weight(model: Union[Callable, str]) -> W: +def get_model_weights(model: Union[Callable, str]) -> W: """ - Retuns the Weights Enum of a model. + Retuns the weights enum class associated to the given model. Args: name (callable or str): The model builder function or the name under which it is registered.