Skip to content

Commit bec45cd

Browse files
authored
Adding multiweight support for squeezenet prototype model (#4817)
* Change enum name for weights contributed by community. * Adding multiweight support to squeezenet.
1 parent c4b5b67 commit bec45cd

File tree

4 files changed

+81
-6
lines changed

4 files changed

+81
-6
lines changed

torchvision/prototype/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .regnet import *
99
from .resnet import *
1010
from .shufflenetv2 import *
11+
from .squeezenet import *
1112
from .vgg import *
1213
from . import detection
1314
from . import quantization

torchvision/prototype/models/googlenet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818

1919
class GoogLeNetWeights(Weights):
20-
ImageNet1K_TheCodezV1 = WeightEntry(
20+
ImageNet1K_Community = WeightEntry(
2121
url="https://download.pytorch.org/models/googlenet-1378be20.pth",
2222
transforms=partial(ImageNetEval, crop_size=224),
2323
meta={
@@ -32,7 +32,7 @@ class GoogLeNetWeights(Weights):
3232
def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
3333
if "pretrained" in kwargs:
3434
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
35-
weights = GoogLeNetWeights.ImageNet1K_TheCodezV1 if kwargs.pop("pretrained") else None
35+
weights = GoogLeNetWeights.ImageNet1K_Community if kwargs.pop("pretrained") else None
3636
weights = GoogLeNetWeights.verify(weights)
3737

3838
original_aux_logits = kwargs.get("aux_logits", False)

torchvision/prototype/models/mnasnet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828

2929
class MNASNet0_5Weights(Weights):
30-
ImageNet1K_TrainerV1 = WeightEntry(
30+
ImageNet1K_Community = WeightEntry(
3131
url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
3232
transforms=partial(ImageNetEval, crop_size=224),
3333
meta={
@@ -45,7 +45,7 @@ class MNASNet0_75Weights(Weights):
4545

4646

4747
class MNASNet1_0Weights(Weights):
48-
ImageNet1K_TrainerV1 = WeightEntry(
48+
ImageNet1K_Community = WeightEntry(
4949
url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
5050
transforms=partial(ImageNetEval, crop_size=224),
5151
meta={
@@ -77,7 +77,7 @@ def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs:
7777
def mnasnet0_5(weights: Optional[MNASNet0_5Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
7878
if "pretrained" in kwargs:
7979
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
80-
weights = MNASNet0_5Weights.ImageNet1K_TrainerV1 if kwargs.pop("pretrained") else None
80+
weights = MNASNet0_5Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
8181

8282
weights = MNASNet0_5Weights.verify(weights)
8383

@@ -98,7 +98,7 @@ def mnasnet0_75(weights: Optional[MNASNet0_75Weights] = None, progress: bool = T
9898
def mnasnet1_0(weights: Optional[MNASNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
9999
if "pretrained" in kwargs:
100100
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
101-
weights = MNASNet1_0Weights.ImageNet1K_TrainerV1 if kwargs.pop("pretrained") else None
101+
weights = MNASNet1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
102102
weights = MNASNet1_0Weights.verify(weights)
103103

104104
return _mnasnet(1.0, weights, progress, **kwargs)
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import warnings
2+
from functools import partial
3+
from typing import Any, Optional
4+
5+
from torchvision.transforms.functional import InterpolationMode
6+
7+
from ...models.squeezenet import SqueezeNet
8+
from ..transforms.presets import ImageNetEval
9+
from ._api import Weights, WeightEntry
10+
from ._meta import _IMAGENET_CATEGORIES
11+
12+
13+
__all__ = ["SqueezeNet", "SqueezeNet1_0Weights", "SqueezeNet1_1Weights", "squeezenet1_0", "squeezenet1_1"]
14+
15+
16+
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
17+
18+
19+
class SqueezeNet1_0Weights(Weights):
20+
ImageNet1K_Community = WeightEntry(
21+
url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
22+
transforms=partial(ImageNetEval, crop_size=224),
23+
meta={
24+
**_common_meta,
25+
"recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717",
26+
"acc@1": 58.092,
27+
"acc@5": 80.420,
28+
},
29+
)
30+
31+
32+
class SqueezeNet1_1Weights(Weights):
33+
ImageNet1K_Community = WeightEntry(
34+
url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
35+
transforms=partial(ImageNetEval, crop_size=224),
36+
meta={
37+
**_common_meta,
38+
"recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717",
39+
"acc@1": 58.178,
40+
"acc@5": 80.624,
41+
},
42+
)
43+
44+
45+
def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
46+
if "pretrained" in kwargs:
47+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
48+
weights = SqueezeNet1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
49+
weights = SqueezeNet1_0Weights.verify(weights)
50+
if weights is not None:
51+
kwargs["num_classes"] = len(weights.meta["categories"])
52+
53+
model = SqueezeNet("1_0", **kwargs)
54+
55+
if weights is not None:
56+
model.load_state_dict(weights.state_dict(progress=progress))
57+
58+
return model
59+
60+
61+
def squeezenet1_1(weights: Optional[SqueezeNet1_1Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
62+
if "pretrained" in kwargs:
63+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
64+
weights = SqueezeNet1_1Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
65+
weights = SqueezeNet1_1Weights.verify(weights)
66+
if weights is not None:
67+
kwargs["num_classes"] = len(weights.meta["categories"])
68+
69+
model = SqueezeNet("1_1", **kwargs)
70+
71+
if weights is not None:
72+
model.load_state_dict(weights.state_dict(progress=progress))
73+
74+
return model

0 commit comments

Comments
 (0)