Commit 9decc18

Vincent Moens authored and datumbox committed
[fbsync] simplify model builders (#5001)
Summary:
* simplify model builders
* cleanup
* refactor kwonly to pos or kw handling
* put weight verification back
* revert num categories checks
* fix default weights
* cleanup
* remove manual parameter map
* refactor decorator interface
* address review comments
* cleanup
* refactor callable default
* fix type annotation
* process ungrouped models
* cleanup
* more cleanup
* use decorator for detection models
* add decorator for quantization models
* add decorator for segmentation models
* add decorator for video models
* remove old helpers
* fix resnet50
* Adding verification back on InceptionV3
* Add kwargs in DeeplabV3
* Add kwargs on FCN
* Fix typing on Deeplab
* Fix typing on FCN

Reviewed By: NicolasHug

Differential Revision: D32950943

fbshipit-source-id: c2a5c21e48b7e6fffcbba2e1ee7ab9dd7dc4a306

Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent cecea28 · commit 9decc18

33 files changed: +579 −644 lines
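The gist of the change: each prototype builder previously open-coded the translation from the legacy boolean `pretrained` flag to the new `weights` enum; that boilerplate now lives in a single `@handle_legacy_interface` decorator, and the builders keep only keyword-only new-style parameters. A minimal before/after sketch, condensed from the alexnet diff below:

# Before: each builder translated the legacy flag by hand.
def alexnet(weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
    if type(weights) == bool and weights:
        _deprecated_positional(kwargs, "pretrained", "weights", True)
    if "pretrained" in kwargs:
        weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNet_Weights.ImageNet1K_V1)
    ...

# After: the decorator maps pretrained=True to the declared default weights
# and pretrained=False to None, emitting a deprecation warning either way.
@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.ImageNet1K_V1))
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
    ...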

test/test_prototype_models.py

Lines changed: 86 additions & 0 deletions
@@ -5,6 +5,8 @@
 import torch
 from common_utils import cpu_and_gpu, run_on_env_var
 from torchvision.prototype import models
+from torchvision.prototype.models._api import WeightsEnum, Weights
+from torchvision.prototype.models._utils import handle_legacy_interface

 run_if_test_with_prototype = run_on_env_var(
     "PYTORCH_TEST_WITH_PROTOTYPE",
@@ -164,3 +166,87 @@ def test_old_vs_new_factory(model_fn, dev):

 def test_smoke():
     import torchvision.prototype.models  # noqa: F401
+
+
+# With this filter, every unexpected warning will be turned into an error
+@pytest.mark.filterwarnings("error")
+class TestHandleLegacyInterface:
+    class TestWeights(WeightsEnum):
+        Sentinel = Weights(url="https://pytorch.org", transforms=lambda x: x, meta=dict())
+
+    @pytest.mark.parametrize(
+        "kwargs",
+        [
+            pytest.param(dict(), id="empty"),
+            pytest.param(dict(weights=None), id="None"),
+            pytest.param(dict(weights=TestWeights.Sentinel), id="Weights"),
+        ],
+    )
+    def test_no_warn(self, kwargs):
+        @handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
+        def builder(*, weights=None):
+            pass
+
+        builder(**kwargs)
+
+    @pytest.mark.parametrize("pretrained", (True, False))
+    def test_pretrained_pos(self, pretrained):
+        @handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
+        def builder(*, weights=None):
+            pass
+
+        with pytest.warns(UserWarning, match="positional"):
+            builder(pretrained)
+
+    @pytest.mark.parametrize("pretrained", (True, False))
+    def test_pretrained_kw(self, pretrained):
+        @handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
+        def builder(*, weights=None):
+            pass
+
+        with pytest.warns(UserWarning, match="deprecated"):
+            builder(pretrained=pretrained)
+
+    @pytest.mark.parametrize("pretrained", (True, False))
+    @pytest.mark.parametrize("positional", (True, False))
+    def test_equivalent_behavior_weights(self, pretrained, positional):
+        @handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
+        def builder(*, weights=None):
+            pass
+
+        args, kwargs = ((pretrained,), dict()) if positional else ((), dict(pretrained=pretrained))
+        with pytest.warns(UserWarning, match=f"weights={self.TestWeights.Sentinel if pretrained else None}"):
+            builder(*args, **kwargs)
+
+    def test_multi_params(self):
+        weights_params = ("weights", "weights_other")
+        pretrained_params = [param.replace("weights", "pretrained") for param in weights_params]
+
+        @handle_legacy_interface(
+            **{
+                weights_param: (pretrained_param, self.TestWeights.Sentinel)
+                for weights_param, pretrained_param in zip(weights_params, pretrained_params)
+            }
+        )
+        def builder(*, weights=None, weights_other=None):
+            pass
+
+        for pretrained_param in pretrained_params:
+            with pytest.warns(UserWarning, match="deprecated"):
+                builder(**{pretrained_param: True})
+
+    def test_default_callable(self):
+        @handle_legacy_interface(
+            weights=(
+                "pretrained",
+                lambda kwargs: self.TestWeights.Sentinel if kwargs["flag"] else None,
+            )
+        )
+        def builder(*, weights=None, flag):
+            pass
+
+        with pytest.warns(UserWarning, match="deprecated"):
+            builder(pretrained=True, flag=True)
+
+        with pytest.raises(ValueError, match="weights"):
+            builder(pretrained=True, flag=False)

torchvision/prototype/models/_utils.py

Lines changed: 82 additions & 19 deletions
@@ -1,32 +1,95 @@
+import functools
 import warnings
-from typing import Any, Dict, Optional, TypeVar
+from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union

-from ._api import WeightsEnum
+from torch import nn
+from torchvision.prototype.utils._internal import kwonly_to_pos_or_kw

+from ._api import WeightsEnum

 W = TypeVar("W", bound=WeightsEnum)
+M = TypeVar("M", bound=nn.Module)
 V = TypeVar("V")


-def _deprecated_param(
-    kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: Optional[W]
-) -> Optional[W]:
-    warnings.warn(f"The parameter '{deprecated_param}' is deprecated, please use '{new_param}' instead.")
-    if kwargs.pop(deprecated_param):
-        if default_value is not None:
-            return default_value
-        else:
-            raise ValueError("No checkpoint is available for model.")
-    else:
-        return None
+def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]):
+    """Decorates a model builder with the new interface to make it compatible with the old.
+
+    In particular this handles two things:
+
+    1. Allows positional parameters again, but emits a deprecation warning in case they are used. See
+       :func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details.
+    2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to
+       ``weights=Weights`` and emits a deprecation warning with instructions for the new interface.
+
+    Args:
+        **weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter
+            name and default value for the legacy ``pretrained=True``. The default value can be a callable in
+            which case it will be called with a dictionary of the keyword arguments. The only key that is
+            guaranteed to be in the dictionary is the deprecated parameter name passed as the first element in
+            the tuple. All other parameters should be accessed with :meth:`~dict.get`.
+    """
+
+    def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]:
+        @kwonly_to_pos_or_kw
+        @functools.wraps(builder)
+        def inner_wrapper(*args: Any, **kwargs: Any) -> M:
+            for weights_param, (pretrained_param, default) in weights.items():  # type: ignore[union-attr]
+                # If neither the weights nor the pretrained parameter was passed, or the weights argument
+                # already uses the new-style enum, there is nothing to do. Note that we cannot use `None` as
+                # sentinel for the weights argument, since it is a valid value.
+                sentinel = object()
+                weights_arg = kwargs.get(weights_param, sentinel)
+                if (
+                    (weights_param not in kwargs and pretrained_param not in kwargs)
+                    or isinstance(weights_arg, WeightsEnum)
+                    or weights_arg is None
+                ):
+                    continue
+
+                # If the pretrained parameter was passed as positional argument, it is now mapped to
+                # `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the
+                # current signature to infer the names of positionally passed arguments and thus has no
+                # knowledge that there used to be a pretrained parameter.
+                pretrained_positional = weights_arg is not sentinel
+                if pretrained_positional:
+                    # We put the pretrained argument under its legacy name in the keyword argument dictionary
+                    # to have unified access to the value if the default value is a callable.
+                    kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param)
+                else:
+                    pretrained_arg = kwargs[pretrained_param]
+
+                if pretrained_arg:
+                    default_weights_arg = default(kwargs) if callable(default) else default
+                    if not isinstance(default_weights_arg, WeightsEnum):
+                        raise ValueError(f"No weights available for model {builder.__name__}")
+                else:
+                    default_weights_arg = None
+
+                if not pretrained_positional:
+                    warnings.warn(
+                        f"The parameter '{pretrained_param}' is deprecated, please use '{weights_param}' instead."
+                    )
+
+                msg = (
+                    f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated. "
+                    f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`."
+                )
+                if pretrained_arg:
+                    msg = (
+                        f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.default` "
+                        f"to get the most up-to-date weights."
+                    )
+                warnings.warn(msg)
+
+                del kwargs[pretrained_param]
+                kwargs[weights_param] = default_weights_arg
+
+            return builder(*args, **kwargs)

+        return inner_wrapper

-def _deprecated_positional(kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: V) -> None:
-    warnings.warn(
-        f"The positional parameter '{deprecated_param}' is deprecated, please use keyword parameter '{new_param}'"
-        + " instead."
-    )
-    kwargs[deprecated_param] = default_value
+    return outer_wrapper


 def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
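To make the decorator's behavior concrete, here is a small self-contained sketch; `DemoWeights` and `demo_builder` are hypothetical names for illustration only (modeled on the tests above), not part of torchvision:

from torchvision.prototype.models._api import Weights, WeightsEnum
from torchvision.prototype.models._utils import handle_legacy_interface


class DemoWeights(WeightsEnum):  # hypothetical enum, for illustration only
    ImageNet1K_V1 = Weights(url="https://example.com/demo.pth", transforms=lambda x: x, meta=dict())
    default = ImageNet1K_V1


@handle_legacy_interface(weights=("pretrained", DemoWeights.ImageNet1K_V1))
def demo_builder(*, weights=None, progress=True):  # hypothetical builder
    return weights


demo_builder(weights=DemoWeights.ImageNet1K_V1)  # new interface: no warning
demo_builder(pretrained=True)   # warns and forwards weights=DemoWeights.ImageNet1K_V1
demo_builder(pretrained=False)  # warns and forwards weights=None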

torchvision/prototype/models/alexnet.py

Lines changed: 3 additions & 6 deletions
@@ -7,7 +7,7 @@
 from ...models.alexnet import AlexNet
 from ._api import WeightsEnum, Weights
 from ._meta import _IMAGENET_CATEGORIES
-from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
+from ._utils import handle_legacy_interface, _ovewrite_named_param


 __all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
@@ -29,11 +29,8 @@ class AlexNet_Weights(WeightsEnum):
     default = ImageNet1K_V1


-def alexnet(weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
-    if type(weights) == bool and weights:
-        _deprecated_positional(kwargs, "pretrained", "weights", True)
-    if "pretrained" in kwargs:
-        weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNet_Weights.ImageNet1K_V1)
+@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.ImageNet1K_V1))
+def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
     weights = AlexNet_Weights.verify(weights)

     if weights is not None:

torchvision/prototype/models/densenet.py

Lines changed: 9 additions & 21 deletions
@@ -9,7 +9,7 @@
 from ...models.densenet import DenseNet
 from ._api import WeightsEnum, Weights
 from ._meta import _IMAGENET_CATEGORIES
-from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
+from ._utils import handle_legacy_interface, _ovewrite_named_param


 __all__ = [
@@ -123,41 +123,29 @@ class DenseNet201_Weights(WeightsEnum):
     default = ImageNet1K_V1


-def densenet121(weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
-    if type(weights) == bool and weights:
-        _deprecated_positional(kwargs, "pretrained", "weights", True)
-    if "pretrained" in kwargs:
-        weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121_Weights.ImageNet1K_V1)
+@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.ImageNet1K_V1))
+def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
     weights = DenseNet121_Weights.verify(weights)

     return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)


-def densenet161(weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
-    if type(weights) == bool and weights:
-        _deprecated_positional(kwargs, "pretrained", "weights", True)
-    if "pretrained" in kwargs:
-        weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161_Weights.ImageNet1K_V1)
+@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.ImageNet1K_V1))
+def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
     weights = DenseNet161_Weights.verify(weights)

     return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)


-def densenet169(weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
-    if type(weights) == bool and weights:
-        _deprecated_positional(kwargs, "pretrained", "weights", True)
-    if "pretrained" in kwargs:
-        weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169_Weights.ImageNet1K_V1)
+@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.ImageNet1K_V1))
+def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
     weights = DenseNet169_Weights.verify(weights)

     return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)


-def densenet201(weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
-    if type(weights) == bool and weights:
-        _deprecated_positional(kwargs, "pretrained", "weights", True)
-    if "pretrained" in kwargs:
-        weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201_Weights.ImageNet1K_V1)
+@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.ImageNet1K_V1))
+def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
     weights = DenseNet201_Weights.verify(weights)

     return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
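The callable-default form documented in the _utils.py docstring (and exercised by test_default_callable) lets the resolved legacy weights depend on other keyword arguments. A sketch, reusing the hypothetical DemoWeights enum from the sketch after _utils.py above; demo_quantized_builder and its quantize flag are likewise made up for illustration:

@handle_legacy_interface(
    weights=(
        "pretrained",
        # Per the docstring, only "pretrained" is guaranteed to be in the
        # kwargs dict, so other parameters are read with .get().
        lambda kwargs: DemoWeights.ImageNet1K_V1 if kwargs.get("quantize", False) else None,
    )
)
def demo_quantized_builder(*, weights=None, quantize=False):  # hypothetical
    return weights


demo_quantized_builder(pretrained=True, quantize=True)   # warns, resolves to DemoWeights.ImageNet1K_V1
demo_quantized_builder(pretrained=True, quantize=False)  # raises ValueError: no weights available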
