1
1
import importlib
2
2
import inspect
3
3
import sys
4
- from dataclasses import dataclass , fields
4
+ from dataclasses import dataclass
5
+ from enum import Enum
5
6
from functools import partial
6
7
from inspect import signature
7
8
from types import ModuleType
8
9
from typing import Any , Callable , cast , Dict , List , Mapping , Optional , TypeVar , Union
9
10
10
11
from torch import nn
11
12
12
- from torchvision ._utils import StrEnum
13
-
14
13
from .._internally_replaced_utils import load_state_dict_from_url
15
14
16
15
@@ -65,7 +64,7 @@ def __eq__(self, other: Any) -> bool:
65
64
return self .transforms == other .transforms
66
65
67
66
68
- class WeightsEnum (StrEnum ):
67
+ class WeightsEnum (Enum ):
69
68
"""
70
69
This class is the parent class of all model weights. Each model building method receives an optional `weights`
71
70
parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
@@ -75,14 +74,11 @@ class WeightsEnum(StrEnum):
75
74
value (Weights): The data class entry with the weight information.
76
75
"""
77
76
78
- def __init__ (self , value : Weights ):
79
- self ._value_ = value
80
-
81
77
@classmethod
82
78
def verify (cls , obj : Any ) -> Any :
83
79
if obj is not None :
84
80
if type (obj ) is str :
85
- obj = cls . from_str ( obj .replace (cls .__name__ + "." , "" ))
81
+ obj = cls [ obj .replace (cls .__name__ + "." , "" )]
86
82
elif not isinstance (obj , cls ):
87
83
raise TypeError (
88
84
f"Invalid Weight class provided; expected { cls .__name__ } but received { obj .__class__ .__name__ } ."
@@ -95,12 +91,17 @@ def get_state_dict(self, progress: bool) -> Mapping[str, Any]:
95
91
def __repr__ (self ) -> str :
96
92
return f"{ self .__class__ .__name__ } .{ self ._name_ } "
97
93
98
- def __getattr__ (self , name ):
99
- # Be able to fetch Weights attributes directly
100
- for f in fields (Weights ):
101
- if f .name == name :
102
- return object .__getattribute__ (self .value , name )
103
- return super ().__getattr__ (name )
94
+ @property
95
+ def url (self ):
96
+ return self .value .url
97
+
98
+ @property
99
+ def transforms (self ):
100
+ return self .value .transforms
101
+
102
+ @property
103
+ def meta (self ):
104
+ return self .value .meta
104
105
105
106
106
107
def get_weight (name : str ) -> WeightsEnum :
@@ -134,7 +135,7 @@ def get_weight(name: str) -> WeightsEnum:
134
135
if weights_enum is None :
135
136
raise ValueError (f"The weight enum '{ enum_name } ' for the specific method couldn't be retrieved." )
136
137
137
- return weights_enum . from_str ( value_name )
138
+ return weights_enum [ value_name ]
138
139
139
140
140
141
def get_model_weights (name : Union [Callable , str ]) -> WeightsEnum :
0 commit comments