Skip to content

Commit eb5254d

Browse files
committed
Fixing imports
1 parent 76f017c commit eb5254d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+245
-255
lines changed

test/test_prototype_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import test_models as TM
66
import torch
77
from common_utils import cpu_and_gpu, needs_cuda
8-
from torchvision.prototype import models
98
from torchvision.models._api import WeightsEnum, Weights
9+
from torchvision.prototype import models
1010
from torchvision.prototype.models._utils import handle_legacy_interface
1111

1212
run_if_test_with_prototype = pytest.mark.skipif(

torchvision/_utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import enum
2-
from typing import TypeVar, Type
2+
from typing import Sequence, TypeVar, Type
33

44
T = TypeVar("T", bound=enum.Enum)
55

@@ -18,3 +18,15 @@ def from_str(self: Type[T], member: str) -> T: # type: ignore[misc]
1818

1919
class StrEnum(enum.Enum, metaclass=StrEnumMeta):
2020
pass
21+
22+
23+
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
24+
if not seq:
25+
return ""
26+
if len(seq) == 1:
27+
return f"'{seq[0]}'"
28+
29+
head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
30+
tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
31+
32+
return head + tail

torchvision/models/_utils.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import functools
2+
import inspect
23
import warnings
34
from collections import OrderedDict
45
from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union
56

67
from torch import nn
7-
from torchvision.prototype.utils._internal import kwonly_to_pos_or_kw
88

9+
from .._utils import sequence_to_str
910
from ._api import WeightsEnum
1011

1112

@@ -88,6 +89,60 @@ def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) ->
8889
return new_v
8990

9091

92+
D = TypeVar("D")
93+
94+
95+
def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]:
96+
"""Decorates a function that uses keyword only parameters to also allow them being passed as positionals.
97+
98+
For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``:
99+
100+
.. code::
101+
102+
def old_fn(foo, bar, baz=None):
103+
...
104+
105+
def new_fn(foo, *, bar, baz=None):
106+
...
107+
108+
Calling ``old_fn("foo", "bar, "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep BC
109+
and at the same time warn the user of the deprecation, this decorator can be used:
110+
111+
.. code::
112+
113+
@kwonly_to_pos_or_kw
114+
def new_fn(foo, *, bar, baz=None):
115+
...
116+
117+
new_fn("foo", "bar, "baz")
118+
"""
119+
params = inspect.signature(fn).parameters
120+
121+
try:
122+
keyword_only_start_idx = next(
123+
idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY
124+
)
125+
except StopIteration:
126+
raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None
127+
128+
keyword_only_params = tuple(inspect.signature(fn).parameters)[keyword_only_start_idx:]
129+
130+
@functools.wraps(fn)
131+
def wrapper(*args: Any, **kwargs: Any) -> D:
132+
args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:]
133+
if keyword_only_args:
134+
keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args))
135+
warnings.warn(
136+
f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional "
137+
f"parameter(s) is deprecated. Please use keyword parameter(s) instead."
138+
)
139+
kwargs.update(keyword_only_kwargs)
140+
141+
return fn(*args, **kwargs)
142+
143+
return wrapper
144+
145+
91146
W = TypeVar("W", bound=WeightsEnum)
92147
M = TypeVar("M", bound=nn.Module)
93148
V = TypeVar("V")

torchvision/models/alexnet.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@
44
import torch
55
import torch.nn as nn
66

7+
from ..transforms import ImageClassificationEval, InterpolationMode
78
from ..utils import _log_api_usage_once
89
from ._api import WeightsEnum, Weights
910
from ._meta import _IMAGENET_CATEGORIES
1011
from ._utils import handle_legacy_interface, _ovewrite_named_param
1112

12-
from ..transforms import ImageClassificationEval, InterpolationMode
13-
1413

1514
__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
1615

@@ -94,4 +93,4 @@ def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True,
9493
if weights is not None:
9594
model.load_state_dict(weights.get_state_dict(progress=progress))
9695

97-
return model
96+
return model

torchvision/models/convnext.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@
77

88
from ..ops.misc import Conv2dNormActivation
99
from ..ops.stochastic_depth import StochasticDepth
10-
from ..utils import _log_api_usage_once
11-
1210
from ..transforms import ImageClassificationEval, InterpolationMode
13-
11+
from ..utils import _log_api_usage_once
1412
from ._api import WeightsEnum, Weights
1513
from ._meta import _IMAGENET_CATEGORIES
1614
from ._utils import handle_legacy_interface, _ovewrite_named_param

torchvision/models/densenet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import re
2-
from functools import partial
32
from collections import OrderedDict
3+
from functools import partial
44
from typing import Any, List, Optional, Tuple
55

66
import torch
@@ -11,7 +11,6 @@
1111

1212
from ..transforms import ImageClassificationEval, InterpolationMode
1313
from ..utils import _log_api_usage_once
14-
1514
from ._api import WeightsEnum, Weights
1615
from ._meta import _IMAGENET_CATEGORIES
1716
from ._utils import handle_legacy_interface, _ovewrite_named_param
@@ -277,6 +276,7 @@ def _densenet(
277276
"recipe": "https://github.com/pytorch/vision/pull/116",
278277
}
279278

279+
280280
class DenseNet121_Weights(WeightsEnum):
281281
IMAGENET1K_V1 = Weights(
282282
url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
@@ -398,4 +398,4 @@ def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool
398398
"""
399399
weights = DenseNet201_Weights.verify(weights)
400400

401-
return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
401+
return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)

torchvision/models/efficientnet.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
1313
from ..transforms import ImageClassificationEval, InterpolationMode
1414
from ..utils import _log_api_usage_once
15-
1615
from ._api import WeightsEnum, Weights
1716
from ._meta import _IMAGENET_CATEGORIES
1817
from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible

torchvision/models/googlenet.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
import warnings
2-
from functools import partial
32
from collections import namedtuple
3+
from functools import partial
44
from typing import Optional, Tuple, List, Callable, Any
55

66
import torch
77
import torch.nn as nn
88
import torch.nn.functional as F
99
from torch import Tensor
1010

11-
from ..utils import _log_api_usage_once
1211
from ..transforms import ImageClassificationEval, InterpolationMode
13-
12+
from ..utils import _log_api_usage_once
1413
from ._api import WeightsEnum, Weights
1514
from ._meta import _IMAGENET_CATEGORIES
1615
from ._utils import handle_legacy_interface, _ovewrite_named_param
@@ -333,4 +332,4 @@ def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = T
333332
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
334333
)
335334

336-
return model
335+
return model

torchvision/models/inception.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@
77
import torch.nn.functional as F
88
from torch import nn, Tensor
99

10-
from ..utils import _log_api_usage_once
11-
12-
1310
from ..transforms import ImageClassificationEval, InterpolationMode
11+
from ..utils import _log_api_usage_once
1412
from ._api import WeightsEnum, Weights
1513
from ._meta import _IMAGENET_CATEGORIES
1614
from ._utils import handle_legacy_interface, _ovewrite_named_param
@@ -465,4 +463,4 @@ def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bo
465463
model.aux_logits = False
466464
model.AuxLogits = None
467465

468-
return model
466+
return model

torchvision/models/mnasnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from functools import partial
21
import warnings
2+
from functools import partial
33
from typing import Any, Dict, List, Optional
44

55
import torch

torchvision/models/mobilenetv2.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from functools import partial
21
import warnings
2+
from functools import partial
33
from typing import Callable, Any, Optional, List
44

55
import torch
@@ -9,7 +9,6 @@
99
from ..ops.misc import Conv2dNormActivation
1010
from ..transforms import ImageClassificationEval, InterpolationMode
1111
from ..utils import _log_api_usage_once
12-
1312
from ._api import WeightsEnum, Weights
1413
from ._meta import _IMAGENET_CATEGORIES
1514
from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible

torchvision/models/mobilenetv3.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer
99
from ..transforms import ImageClassificationEval, InterpolationMode
1010
from ..utils import _log_api_usage_once
11-
1211
from ._api import WeightsEnum, Weights
1312
from ._meta import _IMAGENET_CATEGORIES
1413
from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible

torchvision/models/quantization/googlenet.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,21 @@
55
import torch.nn as nn
66
from torch import Tensor
77
from torch.nn import functional as F
8-
from torchvision.models.googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, model_urls
8+
from torchvision.models.googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet
99

1010
from ..._internally_replaced_utils import load_state_dict_from_url
1111
from .utils import _fuse_modules, _replace_relu, quantize_model
1212

1313

1414
__all__ = ["QuantizableGoogLeNet", "googlenet"]
1515

16+
17+
model_urls = {
18+
# GoogLeNet ported from TensorFlow
19+
"googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth",
20+
}
21+
22+
1623
quant_model_urls = {
1724
# fp32 GoogLeNet ported from TensorFlow, with weights quantized in PyTorch
1825
"googlenet_fbgemm": "https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth",

torchvision/models/quantization/inception.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@
1818
]
1919

2020

21+
model_urls = {
22+
# Inception v3 ported from TensorFlow
23+
"inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
24+
}
25+
26+
2127
quant_model_urls = {
2228
# fp32 weights ported from TensorFlow, quantized in PyTorch
2329
"inception_v3_google_fbgemm": "https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth"
@@ -225,7 +231,7 @@ def inception_v3(
225231
model.AuxLogits = None
226232
model_url = quant_model_urls["inception_v3_google_" + backend]
227233
else:
228-
model_url = inception_module.model_urls["inception_v3_google"]
234+
model_url = model_urls["inception_v3_google"]
229235

230236
state_dict = load_state_dict_from_url(model_url, progress=progress)
231237

torchvision/models/quantization/mobilenetv2.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from torch import Tensor
44
from torch import nn
55
from torch.ao.quantization import QuantStub, DeQuantStub
6-
from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls
6+
from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2
77

88
from ..._internally_replaced_utils import load_state_dict_from_url
99
from ...ops.misc import Conv2dNormActivation
@@ -12,6 +12,12 @@
1212

1313
__all__ = ["QuantizableMobileNetV2", "mobilenet_v2"]
1414

15+
16+
model_urls = {
17+
"mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
18+
}
19+
20+
1521
quant_model_urls = {
1622
"mobilenet_v2_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth"
1723
}

torchvision/models/quantization/mobilenetv3.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,18 @@
66

77
from ..._internally_replaced_utils import load_state_dict_from_url
88
from ...ops.misc import Conv2dNormActivation, SqueezeExcitation
9-
from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, model_urls, _mobilenet_v3_conf
9+
from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, _mobilenet_v3_conf
1010
from .utils import _fuse_modules, _replace_relu
1111

1212

1313
__all__ = ["QuantizableMobileNetV3", "mobilenet_v3_large"]
1414

15+
16+
model_urls = {
17+
"mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
18+
}
19+
20+
1521
quant_model_urls = {
1622
"mobilenet_v3_large_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
1723
}

torchvision/models/quantization/resnet.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,21 @@
33
import torch
44
import torch.nn as nn
55
from torch import Tensor
6-
from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls
6+
from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet
77

88
from ..._internally_replaced_utils import load_state_dict_from_url
99
from .utils import _fuse_modules, _replace_relu, quantize_model
1010

1111
__all__ = ["QuantizableResNet", "resnet18", "resnet50", "resnext101_32x8d"]
1212

1313

14+
model_urls = {
15+
"resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
16+
"resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
17+
"resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
18+
}
19+
20+
1421
quant_model_urls = {
1522
"resnet18_fbgemm": "https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
1623
"resnet50_fbgemm": "https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",

torchvision/models/quantization/shufflenetv2.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@
1414
"shufflenet_v2_x1_0",
1515
]
1616

17+
18+
model_urls = {
19+
"shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
20+
"shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
21+
}
22+
23+
1724
quant_model_urls = {
1825
"shufflenetv2_x0.5_fbgemm": "https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth",
1926
"shufflenetv2_x1.0_fbgemm": "https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth",
@@ -96,7 +103,7 @@ def _shufflenetv2(
96103
if quantize:
97104
model_url = quant_model_urls[arch + "_" + backend]
98105
else:
99-
model_url = shufflenetv2.model_urls[arch]
106+
model_url = model_urls[arch]
100107

101108
state_dict = load_state_dict_from_url(model_url, progress=progress)
102109

0 commit comments

Comments
 (0)