Skip to content

Commit 135a0f9

Browse files
authored
Make WeightEnum and Weights public + cleanups (#7100)
1 parent cb8c441 commit 135a0f9

File tree

3 files changed

+23
-17
lines changed

3 files changed

+23
-17
lines changed

test/test_extended_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
from common_extended_utils import get_file_size_mb, get_ops
99
from torchvision import models
10-
from torchvision.models._api import get_model_weights, Weights, WeightsEnum
10+
from torchvision.models import get_model_weights, Weights, WeightsEnum
1111
from torchvision.models._utils import handle_legacy_interface
1212

1313
run_if_test_with_extended = pytest.mark.skipif(

torchvision/models/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,9 @@
1515
from .swin_transformer import *
1616
from .maxvit import *
1717
from . import detection, optical_flow, quantization, segmentation, video
18-
from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models
18+
19+
# The Weights and WeightsEnum are developer-facing utils that we make public for
20+
# downstream libs like torchgeo https://github.com/pytorch/vision/issues/7094
21+
# TODO: we could / should document them publicly, but it's not clear where, as
22+
# they're not intended for end users.
23+
from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models, Weights, WeightsEnum

torchvision/models/_api.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
import importlib
22
import inspect
33
import sys
4-
from dataclasses import dataclass, fields
4+
from dataclasses import dataclass
5+
from enum import Enum
56
from functools import partial
67
from inspect import signature
78
from types import ModuleType
89
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union
910

1011
from torch import nn
1112

12-
from torchvision._utils import StrEnum
13-
1413
from .._internally_replaced_utils import load_state_dict_from_url
1514

1615

@@ -65,7 +64,7 @@ def __eq__(self, other: Any) -> bool:
6564
return self.transforms == other.transforms
6665

6766

68-
class WeightsEnum(StrEnum):
67+
class WeightsEnum(Enum):
6968
"""
7069
This class is the parent class of all model weights. Each model building method receives an optional `weights`
7170
parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
@@ -75,14 +74,11 @@ class WeightsEnum(StrEnum):
7574
value (Weights): The data class entry with the weight information.
7675
"""
7776

78-
def __init__(self, value: Weights):
79-
self._value_ = value
80-
8177
@classmethod
8278
def verify(cls, obj: Any) -> Any:
8379
if obj is not None:
8480
if type(obj) is str:
85-
obj = cls.from_str(obj.replace(cls.__name__ + ".", ""))
81+
obj = cls[obj.replace(cls.__name__ + ".", "")]
8682
elif not isinstance(obj, cls):
8783
raise TypeError(
8884
f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
@@ -95,12 +91,17 @@ def get_state_dict(self, progress: bool) -> Mapping[str, Any]:
9591
def __repr__(self) -> str:
9692
return f"{self.__class__.__name__}.{self._name_}"
9793

98-
def __getattr__(self, name):
99-
# Be able to fetch Weights attributes directly
100-
for f in fields(Weights):
101-
if f.name == name:
102-
return object.__getattribute__(self.value, name)
103-
return super().__getattr__(name)
94+
@property
95+
def url(self):
96+
return self.value.url
97+
98+
@property
99+
def transforms(self):
100+
return self.value.transforms
101+
102+
@property
103+
def meta(self):
104+
return self.value.meta
104105

105106

106107
def get_weight(name: str) -> WeightsEnum:
@@ -134,7 +135,7 @@ def get_weight(name: str) -> WeightsEnum:
134135
if weights_enum is None:
135136
raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.")
136137

137-
return weights_enum.from_str(value_name)
138+
return weights_enum[value_name]
138139

139140

140141
def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:

0 commit comments

Comments
 (0)