Skip to content

Commit af4a232

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] Adding Weights classes for Resnet classification models (#4655)
Summary: * adding Weights classes for Resnet classification models * Replacing BasicBlock by Bottleneck in all but 3 model contructors * adding tests for prototype models * fixing typo in environment variable * Update test/test_prototype_models.py * changing default value for PYTORCH_TEST_WITH_PROTOTYPE * adding checks to compare outputs of the prototype vs old models * refactoring prototype tests * removing unused imports * applying ufmt * Update test/test_prototype_models.py * Update test/test_prototype_models.py * Update test/test_prototype_models.py * Update test/test_prototype_models.py * Update test/test_prototype_models.py * Update test/test_prototype_models.py * Update test/test_prototype_models.py Reviewed By: NicolasHug Differential Revision: D31916320 fbshipit-source-id: 6a379c2318de8448bf2b89174e789b866071eba5 Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 4cb496d commit af4a232

File tree

2 files changed

+261
-1
lines changed

2 files changed

+261
-1
lines changed

test/test_prototype_models.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,56 @@
1+
import os
2+
3+
import pytest
4+
import torch
5+
from common_utils import set_rng_seed, cpu_and_gpu
6+
from test_models import _assert_expected, _model_params
7+
from torchvision import models as original_models
8+
from torchvision.prototype import models
9+
10+
11+
def get_available_classification_models():
12+
return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
13+
14+
15+
@pytest.mark.parametrize("model_name", get_available_classification_models())
16+
@pytest.mark.parametrize("dev", cpu_and_gpu())
17+
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
18+
def test_classification_model(model_name, dev):
19+
set_rng_seed(0)
20+
defaults = {
21+
"num_classes": 50,
22+
"input_shape": (1, 3, 224, 224),
23+
}
24+
kwargs = {**defaults, **_model_params.get(model_name, {})}
25+
input_shape = kwargs.pop("input_shape")
26+
model = models.__dict__[model_name](**kwargs)
27+
model.eval().to(device=dev)
28+
x = torch.rand(input_shape).to(device=dev)
29+
out = model(x)
30+
_assert_expected(out.cpu(), model_name, prec=0.1)
31+
assert out.shape[-1] == 50
32+
33+
34+
@pytest.mark.parametrize("model_name", get_available_classification_models())
35+
@pytest.mark.parametrize("dev", cpu_and_gpu())
36+
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
37+
def test_old_vs_new_classification_factory(model_name, dev):
38+
defaults = {
39+
"pretrained": True,
40+
"input_shape": (1, 3, 224, 224),
41+
}
42+
kwargs = {**defaults, **_model_params.get(model_name, {})}
43+
input_shape = kwargs.pop("input_shape")
44+
model_old = original_models.__dict__[model_name](**kwargs)
45+
model_old.eval().to(device=dev)
46+
x = torch.rand(input_shape).to(device=dev)
47+
out_old = model_old(x)
48+
# compare with new model builder parameterized in the old fashion way
49+
model_new = models.__dict__[model_name](**kwargs)
50+
model_new.eval().to(device=dev)
51+
out_new = model_new(x)
52+
torch.testing.assert_close(out_new, out_old, rtol=0.0, atol=0.0, check_dtype=False)
53+
54+
155
def test_smoke():
256
import torchvision.prototype.models # noqa: F401

torchvision/prototype/models/resnet.py

Lines changed: 207 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,27 @@
88
from ._meta import _IMAGENET_CATEGORIES
99

1010

11-
__all__ = ["ResNet", "ResNet50Weights", "resnet50"]
11+
__all__ = [
12+
"ResNet",
13+
"ResNet18Weights",
14+
"ResNet34Weights",
15+
"ResNet50Weights",
16+
"ResNet101Weights",
17+
"ResNet152Weights",
18+
"ResNeXt50_32x4dWeights",
19+
"ResNeXt101_32x8dWeights",
20+
"WideResNet50_2Weights",
21+
"WideResNet101_2Weights",
22+
"resnet18",
23+
"resnet34",
24+
"resnet50",
25+
"resnet101",
26+
"resnet152",
27+
"resnext50_32x4d",
28+
"resnext101_32x8d",
29+
"wide_resnet50_2",
30+
"wide_resnet101_2",
31+
]
1232

1333

1434
def _resnet(
@@ -35,6 +55,32 @@ def _resnet(
3555
}
3656

3757

58+
class ResNet18Weights(Weights):
59+
ImageNet1K_RefV1 = WeightEntry(
60+
url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
61+
transforms=partial(ImageNetEval, crop_size=224),
62+
meta={
63+
**_common_meta,
64+
"recipe": "",
65+
"acc@1": 69.758,
66+
"acc@5": 89.078,
67+
},
68+
)
69+
70+
71+
class ResNet34Weights(Weights):
72+
ImageNet1K_RefV1 = WeightEntry(
73+
url="https://download.pytorch.org/models/resnet34-b627a593.pth",
74+
transforms=partial(ImageNetEval, crop_size=224),
75+
meta={
76+
**_common_meta,
77+
"recipe": "",
78+
"acc@1": 73.314,
79+
"acc@5": 91.420,
80+
},
81+
)
82+
83+
3884
class ResNet50Weights(Weights):
3985
ImageNet1K_RefV1 = WeightEntry(
4086
url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
@@ -58,10 +104,170 @@ class ResNet50Weights(Weights):
58104
)
59105

60106

107+
class ResNet101Weights(Weights):
108+
ImageNet1K_RefV1 = WeightEntry(
109+
url="https://download.pytorch.org/models/resnet101-63fe2227.pth",
110+
transforms=partial(ImageNetEval, crop_size=224),
111+
meta={
112+
**_common_meta,
113+
"recipe": "",
114+
"acc@1": 77.374,
115+
"acc@5": 93.546,
116+
},
117+
)
118+
119+
120+
class ResNet152Weights(Weights):
121+
ImageNet1K_RefV1 = WeightEntry(
122+
url="https://download.pytorch.org/models/resnet152-394f9c45.pth",
123+
transforms=partial(ImageNetEval, crop_size=224),
124+
meta={
125+
**_common_meta,
126+
"recipe": "",
127+
"acc@1": 78.312,
128+
"acc@5": 94.046,
129+
},
130+
)
131+
132+
133+
class ResNeXt50_32x4dWeights(Weights):
134+
ImageNet1K_RefV1 = WeightEntry(
135+
url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
136+
transforms=partial(ImageNetEval, crop_size=224),
137+
meta={
138+
**_common_meta,
139+
"recipe": "",
140+
"acc@1": 77.618,
141+
"acc@5": 93.698,
142+
},
143+
)
144+
145+
146+
class ResNeXt101_32x8dWeights(Weights):
147+
ImageNet1K_RefV1 = WeightEntry(
148+
url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
149+
transforms=partial(ImageNetEval, crop_size=224),
150+
meta={
151+
**_common_meta,
152+
"recipe": "",
153+
"acc@1": 79.312,
154+
"acc@5": 94.526,
155+
},
156+
)
157+
158+
159+
class WideResNet50_2Weights(Weights):
160+
ImageNet1K_RefV1 = WeightEntry(
161+
url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
162+
transforms=partial(ImageNetEval, crop_size=224),
163+
meta={
164+
**_common_meta,
165+
"recipe": "",
166+
"acc@1": 78.468,
167+
"acc@5": 94.086,
168+
},
169+
)
170+
171+
172+
class WideResNet101_2Weights(Weights):
173+
ImageNet1K_RefV1 = WeightEntry(
174+
url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
175+
transforms=partial(ImageNetEval, crop_size=224),
176+
meta={
177+
**_common_meta,
178+
"recipe": "",
179+
"acc@1": 78.848,
180+
"acc@5": 94.284,
181+
},
182+
)
183+
184+
185+
def resnet18(weights: Optional[ResNet18Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
186+
if "pretrained" in kwargs:
187+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
188+
weights = ResNet18Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
189+
190+
weights = ResNet18Weights.verify(weights)
191+
192+
return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
193+
194+
195+
def resnet34(weights: Optional[ResNet34Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
196+
if "pretrained" in kwargs:
197+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
198+
weights = ResNet34Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
199+
200+
weights = ResNet34Weights.verify(weights)
201+
202+
return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)
203+
204+
61205
def resnet50(weights: Optional[ResNet50Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
62206
if "pretrained" in kwargs:
63207
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
64208
weights = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
65209
weights = ResNet50Weights.verify(weights)
66210

67211
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
212+
213+
214+
def resnet101(weights: Optional[ResNet101Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
215+
if "pretrained" in kwargs:
216+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
217+
weights = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
218+
219+
weights = ResNet101Weights.verify(weights)
220+
221+
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
222+
223+
224+
def resnet152(weights: Optional[ResNet152Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
225+
if "pretrained" in kwargs:
226+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
227+
weights = ResNet152Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
228+
229+
weights = ResNet152Weights.verify(weights)
230+
231+
return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs)
232+
233+
234+
def resnext50_32x4d(weights: Optional[ResNeXt50_32x4dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
235+
if "pretrained" in kwargs:
236+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
237+
weights = ResNeXt50_32x4dWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
238+
239+
weights = ResNeXt50_32x4dWeights.verify(weights)
240+
kwargs["groups"] = 32
241+
kwargs["width_per_group"] = 4
242+
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
243+
244+
245+
def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
246+
if "pretrained" in kwargs:
247+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
248+
weights = ResNeXt101_32x8dWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
249+
250+
weights = ResNeXt101_32x8dWeights.verify(weights)
251+
kwargs["groups"] = 32
252+
kwargs["width_per_group"] = 8
253+
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
254+
255+
256+
def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
257+
if "pretrained" in kwargs:
258+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
259+
weights = WideResNet50_2Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
260+
261+
weights = WideResNet50_2Weights.verify(weights)
262+
kwargs["width_per_group"] = 64 * 2
263+
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
264+
265+
266+
def wide_resnet101_2(weights: Optional[WideResNet101_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
267+
if "pretrained" in kwargs:
268+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
269+
weights = WideResNet101_2Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
270+
271+
weights = WideResNet101_2Weights.verify(weights)
272+
kwargs["width_per_group"] = 64 * 2
273+
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)

0 commit comments

Comments
 (0)