Skip to content

Commit c25ec4a

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] Adding new ResNet50 weights (#4734)
Summary: * Update model checkpoint for resnet50. * Add get_weight method to retrieve weights from name. * Update the references to support prototype weights. * Fixing mypy typing. * Switching to a python 3.6 supported equivalent. * Add unit-test. * Add optional num_classes. Reviewed By: NicolasHug Differential Revision: D31916330 fbshipit-source-id: 2ac0f9202f62a78078b0917e6730d7fc0925acdf
1 parent b4d3b5c commit c25ec4a

File tree

4 files changed

+72
-9
lines changed

4 files changed

+72
-9
lines changed

references/classification/train.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@
1414
from torchvision.transforms.functional import InterpolationMode
1515

1616

17+
try:
18+
from torchvision.prototype import models as PM
19+
except ImportError:
20+
PM = None
21+
22+
1723
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
1824
model.train()
1925
metric_logger = utils.MetricLogger(delimiter=" ")
@@ -142,11 +148,18 @@ def load_data(traindir, valdir, args):
142148
print("Loading dataset_test from {}".format(cache_path))
143149
dataset_test, _ = torch.load(cache_path)
144150
else:
151+
if not args.weights:
152+
preprocessing = presets.ClassificationPresetEval(
153+
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
154+
)
155+
else:
156+
fn = PM.__dict__[args.model]
157+
weights = PM._api.get_weight(fn, args.weights)
158+
preprocessing = weights.transforms()
159+
145160
dataset_test = torchvision.datasets.ImageFolder(
146161
valdir,
147-
presets.ClassificationPresetEval(
148-
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
149-
),
162+
preprocessing,
150163
)
151164
if args.cache_dataset:
152165
print("Saving dataset_test to {}".format(cache_path))
@@ -206,7 +219,12 @@ def main(args):
206219
)
207220

208221
print("Creating model")
209-
model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
222+
if not args.weights:
223+
model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
224+
else:
225+
if PM is None:
226+
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
227+
model = PM.__dict__[args.model](weights=args.weights, num_classes=num_classes)
210228
model.to(device)
211229

212230
if args.distributed and args.sync_bn:
@@ -455,6 +473,9 @@ def get_args_parser(add_help=True):
455473
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
456474
)
457475

476+
# Prototype models only
477+
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
478+
458479
return parser
459480

460481

test/test_prototype_models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ def get_available_classification_models():
1212
return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
1313

1414

15+
def test_get_weight():
16+
fn = models.resnet50
17+
weight_name = "ImageNet1K_RefV2"
18+
assert models._api.get_weight(fn, weight_name) == models.ResNet50Weights.ImageNet1K_RefV2
19+
20+
1521
@pytest.mark.parametrize("model_name", get_available_classification_models())
1622
@pytest.mark.parametrize("dev", cpu_and_gpu())
1723
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")

torchvision/prototype/models/_api.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from collections import OrderedDict
22
from dataclasses import dataclass, fields
33
from enum import Enum
4+
from inspect import signature
45
from typing import Any, Callable, Dict
56

67
from ..._internally_replaced_utils import load_state_dict_from_url
78

89

9-
__all__ = ["Weights", "WeightEntry"]
10+
__all__ = ["Weights", "WeightEntry", "get_weight"]
1011

1112

1213
@dataclass
@@ -74,3 +75,38 @@ def __getattr__(self, name):
7475
if f.name == name:
7576
return object.__getattribute__(self.value, name)
7677
return super().__getattr__(name)
78+
79+
80+
def get_weight(fn: Callable, weight_name: str) -> Weights:
81+
"""
82+
Gets the weight enum of a specific model builder method and weight name combination.
83+
84+
Args:
85+
fn (Callable): The builder method used to create the model.
86+
weight_name (str): The name of the weight enum entry of the specific model.
87+
88+
Returns:
89+
Weights: The requested weight enum.
90+
"""
91+
sig = signature(fn)
92+
if "weights" not in sig.parameters:
93+
raise ValueError("The method is missing the 'weights' argument.")
94+
95+
ann = signature(fn).parameters["weights"].annotation
96+
weights_class = None
97+
if isinstance(ann, type) and issubclass(ann, Weights):
98+
weights_class = ann
99+
else:
100+
# handle cases like Union[Optional, T]
101+
# TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
102+
for t in ann.__args__: # type: ignore[union-attr]
103+
if isinstance(t, type) and issubclass(t, Weights):
104+
weights_class = t
105+
break
106+
107+
if weights_class is None:
108+
raise ValueError(
109+
"The weight class for the specific method couldn't be retrieved. Make sure the typing info is " "correct."
110+
)
111+
112+
return weights_class.from_str(weight_name)

torchvision/prototype/models/resnet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,13 @@ class ResNet50Weights(Weights):
9292
},
9393
)
9494
ImageNet1K_RefV2 = WeightEntry(
95-
url="https://download.pytorch.org/models/resnet50-tmp.pth",
96-
transforms=partial(ImageNetEval, crop_size=224),
95+
url="https://download.pytorch.org/models/resnet50-f46c3f97.pth",
96+
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
9797
meta={
9898
**_common_meta,
9999
"recipe": "https://github.com/pytorch/vision/issues/3995",
100-
"acc@1": 80.352,
101-
"acc@5": 95.148,
100+
"acc@1": 80.674,
101+
"acc@5": 95.166,
102102
},
103103
)
104104

0 commit comments

Comments
 (0)