Skip to content

Commit 2cd25c1

Browse files
authored
Fix resnet_fpn_backbone(pretrained=True) (#7172)
1 parent 135a0f9 commit 2cd25c1

File tree

3 files changed

+12
-7
lines changed

3 files changed

+12
-7
lines changed

test/test_extended_models.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from torchvision import models
1010
from torchvision.models import get_model_weights, Weights, WeightsEnum
1111
from torchvision.models._utils import handle_legacy_interface
12+
from torchvision.models.detection.backbone_utils import mobilenet_backbone, resnet_fpn_backbone
1213

1314
run_if_test_with_extended = pytest.mark.skipif(
1415
os.getenv("PYTORCH_TEST_WITH_EXTENDED", "0") != "1",
@@ -425,7 +426,11 @@ def builder(*, weights=None, flag):
425426
+ TM.list_model_fns(models.quantization)
426427
+ TM.list_model_fns(models.segmentation)
427428
+ TM.list_model_fns(models.video)
428-
+ TM.list_model_fns(models.optical_flow),
429+
+ TM.list_model_fns(models.optical_flow)
430+
+ [
431+
lambda pretrained: resnet_fpn_backbone(backbone_name="resnet50", pretrained=pretrained),
432+
lambda pretrained: mobilenet_backbone(backbone_name="mobilenet_v2", fpn=False, pretrained=pretrained),
433+
],
429434
)
430435
@run_if_test_with_extended
431436
def test_pretrained_deprecation(self, model_fn):

torchvision/models/_api.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from functools import partial
77
from inspect import signature
88
from types import ModuleType
9-
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union
9+
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union
1010

1111
from torch import nn
1212

@@ -138,7 +138,7 @@ def get_weight(name: str) -> WeightsEnum:
138138
return weights_enum[value_name]
139139

140140

141-
def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:
141+
def get_model_weights(name: Union[Callable, str]) -> Type[WeightsEnum]:
142142
"""
143143
Returns the weights enum class associated to the given model.
144144
@@ -152,7 +152,7 @@ def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:
152152
return _get_enum_from_fn(model)
153153

154154

155-
def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
155+
def _get_enum_from_fn(fn: Callable) -> Type[WeightsEnum]:
156156
"""
157157
Internal method that gets the weight enum of a specific model builder method.
158158
@@ -182,7 +182,7 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
182182
"The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct."
183183
)
184184

185-
return cast(WeightsEnum, weights_enum)
185+
return weights_enum
186186

187187

188188
M = TypeVar("M", bound=nn.Module)

torchvision/models/detection/backbone_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
6262
@handle_legacy_interface(
6363
weights=(
6464
"pretrained",
65-
lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
65+
lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
6666
),
6767
)
6868
def resnet_fpn_backbone(
@@ -177,7 +177,7 @@ def _validate_trainable_layers(
177177
@handle_legacy_interface(
178178
weights=(
179179
"pretrained",
180-
lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
180+
lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
181181
),
182182
)
183183
def mobilenet_backbone(

0 commit comments

Comments
 (0)