|
| 1 | +import functools |
1 | 2 | import warnings
|
2 |
| -from typing import Any, Dict, Optional, TypeVar |
| 3 | +from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union |
3 | 4 |
|
4 |
| -from ._api import WeightsEnum |
| 5 | +from torch import nn |
| 6 | +from torchvision.prototype.utils._internal import kwonly_to_pos_or_kw |
5 | 7 |
|
| 8 | +from ._api import WeightsEnum |
6 | 9 |
|
7 | 10 | W = TypeVar("W", bound=WeightsEnum)
|
| 11 | +M = TypeVar("M", bound=nn.Module) |
8 | 12 | V = TypeVar("V")
|
9 | 13 |
|
10 | 14 |
|
11 |
| -def _deprecated_param( |
12 |
| - kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: Optional[W] |
13 |
| -) -> Optional[W]: |
14 |
| - warnings.warn(f"The parameter '{deprecated_param}' is deprecated, please use '{new_param}' instead.") |
15 |
| - if kwargs.pop(deprecated_param): |
16 |
| - if default_value is not None: |
17 |
| - return default_value |
18 |
| - else: |
19 |
| - raise ValueError("No checkpoint is available for model.") |
20 |
| - else: |
21 |
| - return None |
| 15 | +def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): |
| 16 | + """Decorates a model builder with the new interface to make it compatible with the old. |
| 17 | +
|
| 18 | + In particular this handles two things: |
| 19 | +
|
| 20 | + 1. Allows positional parameters again, but emits a deprecation warning in case they are used. See |
| 21 | + :func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details. |
| 22 | + 2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to |
| 23 | + ``weights=Weights`` and emits a deprecation warning with instructions for the new interface. |
| 24 | +
|
| 25 | + Args: |
| 26 | + **weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter |
| 27 | + name and default value for the legacy ``pretrained=True``. The default value can be a callable in which |
| 28 | + case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be in |
| 29 | + the dictionary is the deprecated parameter name passed as first element in the tuple. All other parameters |
| 30 | + should be accessed with :meth:`~dict.get`. |
| 31 | + """ |
| 32 | + |
| 33 | + def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]: |
| 34 | + @kwonly_to_pos_or_kw |
| 35 | + @functools.wraps(builder) |
| 36 | + def inner_wrapper(*args: Any, **kwargs: Any) -> M: |
| 37 | + for weights_param, (pretrained_param, default) in weights.items(): # type: ignore[union-attr] |
| 38 | + # If neither the weights nor the pretrained parameter as passed, or the weights argument already use |
| 39 | + # the new style arguments, there is nothing to do. Note that we cannot use `None` as sentinel for the |
| 40 | + # weight argument, since it is a valid value. |
| 41 | + sentinel = object() |
| 42 | + weights_arg = kwargs.get(weights_param, sentinel) |
| 43 | + if ( |
| 44 | + (weights_param not in kwargs and pretrained_param not in kwargs) |
| 45 | + or isinstance(weights_arg, WeightsEnum) |
| 46 | + or weights_arg is None |
| 47 | + ): |
| 48 | + continue |
| 49 | + |
| 50 | + # If the pretrained parameter was passed as positional argument, it is now mapped to |
| 51 | + # `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the current |
| 52 | + # signature to infer the names of positionally passed arguments and thus has no knowledge that there |
| 53 | + # used to be a pretrained parameter. |
| 54 | + pretrained_positional = weights_arg is not sentinel |
| 55 | + if pretrained_positional: |
| 56 | + # We put the pretrained argument under its legacy name in the keyword argument dictionary to have a |
| 57 | + # unified access to the value if the default value is a callable. |
| 58 | + kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param) |
| 59 | + else: |
| 60 | + pretrained_arg = kwargs[pretrained_param] |
| 61 | + |
| 62 | + if pretrained_arg: |
| 63 | + default_weights_arg = default(kwargs) if callable(default) else default |
| 64 | + if not isinstance(default_weights_arg, WeightsEnum): |
| 65 | + raise ValueError(f"No weights available for model {builder.__name__}") |
| 66 | + else: |
| 67 | + default_weights_arg = None |
| 68 | + |
| 69 | + if not pretrained_positional: |
| 70 | + warnings.warn( |
| 71 | + f"The parameter '{pretrained_param}' is deprecated, please use '{weights_param}' instead." |
| 72 | + ) |
| 73 | + |
| 74 | + msg = ( |
| 75 | + f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated. " |
| 76 | + f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`." |
| 77 | + ) |
| 78 | + if pretrained_arg: |
| 79 | + msg = ( |
| 80 | + f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.default` " |
| 81 | + f"to get the most up-to-date weights." |
| 82 | + ) |
| 83 | + warnings.warn(msg) |
| 84 | + |
| 85 | + del kwargs[pretrained_param] |
| 86 | + kwargs[weights_param] = default_weights_arg |
| 87 | + |
| 88 | + return builder(*args, **kwargs) |
22 | 89 |
|
| 90 | + return inner_wrapper |
23 | 91 |
|
24 |
| -def _deprecated_positional(kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: V) -> None: |
25 |
| - warnings.warn( |
26 |
| - f"The positional parameter '{deprecated_param}' is deprecated, please use keyword parameter '{new_param}'" |
27 |
| - + " instead." |
28 |
| - ) |
29 |
| - kwargs[deprecated_param] = default_value |
| 92 | + return outer_wrapper |
30 | 93 |
|
31 | 94 |
|
32 | 95 | def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
|
|
0 commit comments