8
8
from ._meta import _IMAGENET_CATEGORIES
9
9
10
10
11
- __all__ = ["ResNet" , "ResNet50Weights" , "resnet50" ]
11
+ __all__ = [
12
+ "ResNet" ,
13
+ "ResNet18Weights" ,
14
+ "ResNet34Weights" ,
15
+ "ResNet50Weights" ,
16
+ "ResNet101Weights" ,
17
+ "ResNet152Weights" ,
18
+ "ResNeXt50_32x4dWeights" ,
19
+ "ResNeXt101_32x8dWeights" ,
20
+ "WideResNet50_2Weights" ,
21
+ "WideResNet101_2Weights" ,
22
+ "resnet18" ,
23
+ "resnet34" ,
24
+ "resnet50" ,
25
+ "resnet101" ,
26
+ "resnet152" ,
27
+ "resnext50_32x4d" ,
28
+ "resnext101_32x8d" ,
29
+ "wide_resnet50_2" ,
30
+ "wide_resnet101_2" ,
31
+ ]
12
32
13
33
14
34
def _resnet (
@@ -35,6 +55,32 @@ def _resnet(
35
55
}
36
56
37
57
58
+ class ResNet18Weights (Weights ):
59
+ ImageNet1K_RefV1 = WeightEntry (
60
+ url = "https://download.pytorch.org/models/resnet18-f37072fd.pth" ,
61
+ transforms = partial (ImageNetEval , crop_size = 224 ),
62
+ meta = {
63
+ ** _common_meta ,
64
+ "recipe" : "" ,
65
+ "acc@1" : 69.758 ,
66
+ "acc@5" : 89.078 ,
67
+ },
68
+ )
69
+
70
+
71
+ class ResNet34Weights (Weights ):
72
+ ImageNet1K_RefV1 = WeightEntry (
73
+ url = "https://download.pytorch.org/models/resnet34-b627a593.pth" ,
74
+ transforms = partial (ImageNetEval , crop_size = 224 ),
75
+ meta = {
76
+ ** _common_meta ,
77
+ "recipe" : "" ,
78
+ "acc@1" : 73.314 ,
79
+ "acc@5" : 91.420 ,
80
+ },
81
+ )
82
+
83
+
38
84
class ResNet50Weights (Weights ):
39
85
ImageNet1K_RefV1 = WeightEntry (
40
86
url = "https://download.pytorch.org/models/resnet50-0676ba61.pth" ,
@@ -58,10 +104,170 @@ class ResNet50Weights(Weights):
58
104
)
59
105
60
106
107
+ class ResNet101Weights (Weights ):
108
+ ImageNet1K_RefV1 = WeightEntry (
109
+ url = "https://download.pytorch.org/models/resnet101-63fe2227.pth" ,
110
+ transforms = partial (ImageNetEval , crop_size = 224 ),
111
+ meta = {
112
+ ** _common_meta ,
113
+ "recipe" : "" ,
114
+ "acc@1" : 77.374 ,
115
+ "acc@5" : 93.546 ,
116
+ },
117
+ )
118
+
119
+
120
+ class ResNet152Weights (Weights ):
121
+ ImageNet1K_RefV1 = WeightEntry (
122
+ url = "https://download.pytorch.org/models/resnet152-394f9c45.pth" ,
123
+ transforms = partial (ImageNetEval , crop_size = 224 ),
124
+ meta = {
125
+ ** _common_meta ,
126
+ "recipe" : "" ,
127
+ "acc@1" : 78.312 ,
128
+ "acc@5" : 94.046 ,
129
+ },
130
+ )
131
+
132
+
133
+ class ResNeXt50_32x4dWeights (Weights ):
134
+ ImageNet1K_RefV1 = WeightEntry (
135
+ url = "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth" ,
136
+ transforms = partial (ImageNetEval , crop_size = 224 ),
137
+ meta = {
138
+ ** _common_meta ,
139
+ "recipe" : "" ,
140
+ "acc@1" : 77.618 ,
141
+ "acc@5" : 93.698 ,
142
+ },
143
+ )
144
+
145
+
146
+ class ResNeXt101_32x8dWeights (Weights ):
147
+ ImageNet1K_RefV1 = WeightEntry (
148
+ url = "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth" ,
149
+ transforms = partial (ImageNetEval , crop_size = 224 ),
150
+ meta = {
151
+ ** _common_meta ,
152
+ "recipe" : "" ,
153
+ "acc@1" : 79.312 ,
154
+ "acc@5" : 94.526 ,
155
+ },
156
+ )
157
+
158
+
159
+ class WideResNet50_2Weights (Weights ):
160
+ ImageNet1K_RefV1 = WeightEntry (
161
+ url = "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth" ,
162
+ transforms = partial (ImageNetEval , crop_size = 224 ),
163
+ meta = {
164
+ ** _common_meta ,
165
+ "recipe" : "" ,
166
+ "acc@1" : 78.468 ,
167
+ "acc@5" : 94.086 ,
168
+ },
169
+ )
170
+
171
+
172
+ class WideResNet101_2Weights (Weights ):
173
+ ImageNet1K_RefV1 = WeightEntry (
174
+ url = "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth" ,
175
+ transforms = partial (ImageNetEval , crop_size = 224 ),
176
+ meta = {
177
+ ** _common_meta ,
178
+ "recipe" : "" ,
179
+ "acc@1" : 78.848 ,
180
+ "acc@5" : 94.284 ,
181
+ },
182
+ )
183
+
184
+
185
+ def resnet18 (weights : Optional [ResNet18Weights ] = None , progress : bool = True , ** kwargs : Any ) -> ResNet :
186
+ if "pretrained" in kwargs :
187
+ warnings .warn ("The argument pretrained is deprecated, please use weights instead." )
188
+ weights = ResNet18Weights .ImageNet1K_RefV1 if kwargs .pop ("pretrained" ) else None
189
+
190
+ weights = ResNet18Weights .verify (weights )
191
+
192
+ return _resnet (BasicBlock , [2 , 2 , 2 , 2 ], weights , progress , ** kwargs )
193
+
194
+
195
+ def resnet34 (weights : Optional [ResNet34Weights ] = None , progress : bool = True , ** kwargs : Any ) -> ResNet :
196
+ if "pretrained" in kwargs :
197
+ warnings .warn ("The argument pretrained is deprecated, please use weights instead." )
198
+ weights = ResNet34Weights .ImageNet1K_RefV1 if kwargs .pop ("pretrained" ) else None
199
+
200
+ weights = ResNet34Weights .verify (weights )
201
+
202
+ return _resnet (BasicBlock , [3 , 4 , 6 , 3 ], weights , progress , ** kwargs )
203
+
204
+
61
205
def resnet50 (weights : Optional [ResNet50Weights ] = None , progress : bool = True , ** kwargs : Any ) -> ResNet :
62
206
if "pretrained" in kwargs :
63
207
warnings .warn ("The argument pretrained is deprecated, please use weights instead." )
64
208
weights = ResNet50Weights .ImageNet1K_RefV1 if kwargs .pop ("pretrained" ) else None
65
209
weights = ResNet50Weights .verify (weights )
66
210
67
211
return _resnet (Bottleneck , [3 , 4 , 6 , 3 ], weights , progress , ** kwargs )
212
+
213
+
214
+ def resnet101 (weights : Optional [ResNet101Weights ] = None , progress : bool = True , ** kwargs : Any ) -> ResNet :
215
+ if "pretrained" in kwargs :
216
+ warnings .warn ("The argument pretrained is deprecated, please use weights instead." )
217
+ weights = ResNet101Weights .ImageNet1K_RefV1 if kwargs .pop ("pretrained" ) else None
218
+
219
+ weights = ResNet101Weights .verify (weights )
220
+
221
+ return _resnet (Bottleneck , [3 , 4 , 23 , 3 ], weights , progress , ** kwargs )
222
+
223
+
224
+ def resnet152 (weights : Optional [ResNet152Weights ] = None , progress : bool = True , ** kwargs : Any ) -> ResNet :
225
+ if "pretrained" in kwargs :
226
+ warnings .warn ("The argument pretrained is deprecated, please use weights instead." )
227
+ weights = ResNet152Weights .ImageNet1K_RefV1 if kwargs .pop ("pretrained" ) else None
228
+
229
+ weights = ResNet152Weights .verify (weights )
230
+
231
+ return _resnet (Bottleneck , [3 , 8 , 36 , 3 ], weights , progress , ** kwargs )
232
+
233
+
234
+ def resnext50_32x4d (weights : Optional [ResNeXt50_32x4dWeights ] = None , progress : bool = True , ** kwargs : Any ) -> ResNet :
235
+ if "pretrained" in kwargs :
236
+ warnings .warn ("The argument pretrained is deprecated, please use weights instead." )
237
+ weights = ResNeXt50_32x4dWeights .ImageNet1K_RefV1 if kwargs .pop ("pretrained" ) else None
238
+
239
+ weights = ResNeXt50_32x4dWeights .verify (weights )
240
+ kwargs ["groups" ] = 32
241
+ kwargs ["width_per_group" ] = 4
242
+ return _resnet (Bottleneck , [3 , 4 , 6 , 3 ], weights , progress , ** kwargs )
243
+
244
+
245
+ def resnext101_32x8d (weights : Optional [ResNeXt101_32x8dWeights ] = None , progress : bool = True , ** kwargs : Any ) -> ResNet :
246
+ if "pretrained" in kwargs :
247
+ warnings .warn ("The argument pretrained is deprecated, please use weights instead." )
248
+ weights = ResNeXt101_32x8dWeights .ImageNet1K_RefV1 if kwargs .pop ("pretrained" ) else None
249
+
250
+ weights = ResNeXt101_32x8dWeights .verify (weights )
251
+ kwargs ["groups" ] = 32
252
+ kwargs ["width_per_group" ] = 8
253
+ return _resnet (Bottleneck , [3 , 4 , 23 , 3 ], weights , progress , ** kwargs )
254
+
255
+
256
+ def wide_resnet50_2 (weights : Optional [WideResNet50_2Weights ] = None , progress : bool = True , ** kwargs : Any ) -> ResNet :
257
+ if "pretrained" in kwargs :
258
+ warnings .warn ("The argument pretrained is deprecated, please use weights instead." )
259
+ weights = WideResNet50_2Weights .ImageNet1K_RefV1 if kwargs .pop ("pretrained" ) else None
260
+
261
+ weights = WideResNet50_2Weights .verify (weights )
262
+ kwargs ["width_per_group" ] = 64 * 2
263
+ return _resnet (Bottleneck , [3 , 4 , 6 , 3 ], weights , progress , ** kwargs )
264
+
265
+
266
+ def wide_resnet101_2 (weights : Optional [WideResNet101_2Weights ] = None , progress : bool = True , ** kwargs : Any ) -> ResNet :
267
+ if "pretrained" in kwargs :
268
+ warnings .warn ("The argument pretrained is deprecated, please use weights instead." )
269
+ weights = WideResNet101_2Weights .ImageNet1K_RefV1 if kwargs .pop ("pretrained" ) else None
270
+
271
+ weights = WideResNet101_2Weights .verify (weights )
272
+ kwargs ["width_per_group" ] = 64 * 2
273
+ return _resnet (Bottleneck , [3 , 4 , 23 , 3 ], weights , progress , ** kwargs )
0 commit comments