 import torch
 import torch.nn as nn
 from torch import Tensor
-from torchvision.prototype.transforms import ImageNetEval
-from torchvision.transforms.functional import InterpolationMode
 
 from ...utils import _log_api_usage_once
-from ._api import WeightsEnum, Weights
-from ._meta import _IMAGENET_CATEGORIES
+from ._api import WeightsEnum
 from ._utils import handle_legacy_interface
 
+
 __all__ = [
     "VisionTransformer",
     "ViT_B_16_Weights",
@@ -235,70 +233,24 @@ def forward(self, x: torch.Tensor):
         return x
 
 
-_COMMON_META = {
-    "categories": _IMAGENET_CATEGORIES,
-    "interpolation": InterpolationMode.BILINEAR,
-}
-
-
 class ViT_B_16_Weights(WeightsEnum):
-    ImageNet1K_V1 = Weights(
-        url="https://download.pytorch.org/models/vit_b_16-c867db91.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
-        meta={
-            **_COMMON_META,
-            "size": (224, 224),
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_16",
-            "acc@1": 81.072,
-            "acc@5": 95.318,
-        },
-    )
-    default = ImageNet1K_V1
+    # If a default model is added here the corresponding changes need to be done in vit_b_16
+    pass
 
 
 class ViT_B_32_Weights(WeightsEnum):
-    ImageNet1K_V1 = Weights(
-        url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
-        meta={
-            **_COMMON_META,
-            "size": (224, 224),
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_32",
-            "acc@1": 75.912,
-            "acc@5": 92.466,
-        },
-    )
-    default = ImageNet1K_V1
+    # If a default model is added here the corresponding changes need to be done in vit_b_32
+    pass
 
 
 class ViT_L_16_Weights(WeightsEnum):
-    ImageNet1K_V1 = Weights(
-        url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=242),
-        meta={
-            **_COMMON_META,
-            "size": (224, 224),
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_16",
-            "acc@1": 79.662,
-            "acc@5": 94.638,
-        },
-    )
-    default = ImageNet1K_V1
+    # If a default model is added here the corresponding changes need to be done in vit_l_16
+    pass
 
 
 class ViT_L_32_Weights(WeightsEnum):
-    ImageNet1K_V1 = Weights(
-        url="https://download.pytorch.org/models/vit_l_32-c7638314.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
-        meta={
-            **_COMMON_META,
-            "size": (224, 224),
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_32",
-            "acc@1": 76.972,
-            "acc@5": 93.07,
-        },
-    )
-    default = ImageNet1K_V1
+    # If a default model is added here the corresponding changes need to be done in vit_l_32
+    pass
 
 
 def _vision_transformer(
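After this hunk, the four weight enums are empty placeholders: they still subclass `WeightsEnum` but define no members, so no pretrained checkpoint can be requested until entries are restored. A minimal sketch of what that means for callers (not part of the diff; the import path assumes torchvision's prototype namespace):

```python
from torchvision.prototype.models import ViT_B_16_Weights

# An Enum subclass whose body is just `pass` has no members.
assert list(ViT_B_16_Weights) == []

# Consequently the only valid value for the `weights` argument is None.
```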
@@ -329,7 +281,7 @@ def _vision_transformer(
     return model
 
 
-@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.ImageNet1K_V1))
+@handle_legacy_interface(weights=("pretrained", None))
 def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
     """
     Constructs a vit_b_16 architecture from
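With the enums emptied, each builder's legacy default changes from a concrete weight to None, so the deprecated `pretrained` flag no longer resolves to a checkpoint. A hedged usage sketch (assuming the prototype import path; the exact behavior of `pretrained=True` depends on how `handle_legacy_interface` treats a None default, which this diff does not show):

```python
from torchvision.prototype.models import vit_b_16

model = vit_b_16()              # randomly initialized weights
model = vit_b_16(weights=None)  # equivalent, explicit

# Legacy keyword: previously pretrained=True mapped to
# ViT_B_16_Weights.ImageNet1K_V1; after this change the mapped default is
# None, so there is no checkpoint to download.
```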
@@ -354,7 +306,7 @@ def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = Tru
     )
 
 
-@handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.ImageNet1K_V1))
+@handle_legacy_interface(weights=("pretrained", None))
 def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
     """
     Constructs a vit_b_32 architecture from
@@ -379,7 +331,7 @@ def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = Tru
     )
 
 
-@handle_legacy_interface(weights=("pretrained", ViT_L_16_Weights.ImageNet1K_V1))
+@handle_legacy_interface(weights=("pretrained", None))
 def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
     """
     Constructs a vit_l_16 architecture from
@@ -404,7 +356,7 @@ def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = Tru
     )
 
 
-@handle_legacy_interface(weights=("pretrained", ViT_L_32_Weights.ImageNet1K_V1))
+@handle_legacy_interface(weights=("pretrained", None))
 def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
     """
     Constructs a vit_l_32 architecture from