
Commit 03d1133

Adding vit_h_14 architecture (#5210)
* adding vit_h_14
* prototype and docs
* bug fix
* adding curl check
1 parent abc6c77 · commit 03d1133

5 files changed: +49 / -0 lines
docs/source/models.rst

Lines changed: 2 additions & 0 deletions

@@ -88,6 +88,7 @@ You can construct a model with random weights by calling its constructor:
     vit_b_32 = models.vit_b_32()
     vit_l_16 = models.vit_l_16()
     vit_l_32 = models.vit_l_32()
+    vit_h_14 = models.vit_h_14()
 
 We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
 These can be constructed by passing ``pretrained=True``:

@@ -460,6 +461,7 @@ VisionTransformer
     vit_b_32
     vit_l_16
     vit_l_32
+    vit_h_14
 
 Quantized Models
 ----------------
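As the updated docs show, the new variant is constructed with random weights like every other ViT model. A minimal usage sketch (the 1000-class output and the 224x224 default input size are assumptions carried over from torchvision's VisionTransformer defaults, not stated in this diff):

    import torch
    from torchvision import models

    model = models.vit_h_14()  # random weights; no checkpoint exists for ViT-H/14
    model.eval()

    x = torch.randn(1, 3, 224, 224)  # assumes the default image_size of 224
    with torch.no_grad():
        out = model(x)
    print(out.shape)  # torch.Size([1, 1000])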

hubconf.py

Lines changed: 1 addition & 0 deletions

@@ -63,4 +63,5 @@
     vit_b_32,
     vit_l_16,
     vit_l_32,
+    vit_h_14,
 )
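With vit_h_14 exported from hubconf.py, the model should also be reachable through torch.hub. A hedged sketch (fetches the pytorch/vision repo over the network):

    import torch

    # pretrained must stay False: no checkpoint URL is registered for vit_h_14
    model = torch.hub.load("pytorch/vision", "vit_h_14", pretrained=False)
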
(binary file, 939 Bytes, not shown)

torchvision/models/vision_transformer.py

Lines changed: 23 additions & 0 deletions

@@ -15,6 +15,7 @@
     "vit_b_32",
     "vit_l_16",
     "vit_l_32",
+    "vit_h_14",
 ]
 
 model_urls = {

@@ -260,6 +261,8 @@ def _vision_transformer(
     )
 
     if pretrained:
+        if arch not in model_urls:
+            raise ValueError(f"No checkpoint is available for model type '{arch}'!")
         state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
         model.load_state_dict(state_dict)
 
@@ -354,6 +357,26 @@ def vit_l_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
     )
 
 
+def vit_h_14(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
+    """
+    Constructs a vit_h_14 architecture from
+    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
+
+    NOTE: Pretrained weights are not available for this model.
+    """
+    return _vision_transformer(
+        arch="vit_h_14",
+        patch_size=14,
+        num_layers=32,
+        num_heads=16,
+        hidden_dim=1280,
+        mlp_dim=5120,
+        pretrained=pretrained,
+        progress=progress,
+        **kwargs,
+    )
+
+
 def interpolate_embeddings(
     image_size: int,
     patch_size: int,
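The hyperparameters match the paper's ViT-Huge configuration: 32 encoder layers, 16 attention heads, hidden size 1280, MLP size 5120, and 14x14 patches. A quick sanity check of the resulting token sequence, assuming the default 224x224 input:

    image_size, patch_size, hidden_dim = 224, 14, 1280
    num_patches = (image_size // patch_size) ** 2  # 16 * 16 = 256 patches
    seq_len = num_patches + 1                      # plus one class token
    print(seq_len, hidden_dim)                     # 257 1280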

torchvision/prototype/models/vision_transformer.py

Lines changed: 23 additions & 0 deletions

@@ -19,10 +19,12 @@
     "ViT_B_32_Weights",
     "ViT_L_16_Weights",
     "ViT_L_32_Weights",
+    "ViT_H_14_Weights",
     "vit_b_16",
     "vit_b_32",
     "vit_l_16",
     "vit_l_32",
+    "vit_h_14",
 ]
 
 

@@ -99,6 +101,11 @@ class ViT_L_32_Weights(WeightsEnum):
     default = ImageNet1K_V1
 
 
+class ViT_H_14_Weights(WeightsEnum):
+    # Weights are not available yet.
+    pass
+
+
 def _vision_transformer(
     patch_size: int,
     num_layers: int,

@@ -192,3 +199,19 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = Tru
         progress=progress,
         **kwargs,
     )
+
+
+@handle_legacy_interface(weights=("pretrained", None))
+def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
+    weights = ViT_H_14_Weights.verify(weights)
+
+    return _vision_transformer(
+        patch_size=14,
+        num_layers=32,
+        num_heads=16,
+        hidden_dim=1280,
+        mlp_dim=5120,
+        weights=weights,
+        progress=progress,
+        **kwargs,
+    )
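In the prototype multi-weight API, the legacy pretrained flag is mapped onto the weights argument by handle_legacy_interface; since ViT_H_14_Weights defines no entries yet, None (random initialization) is the only meaningful value. A hedged sketch against the prototype namespace:

    from torchvision.prototype.models import ViT_H_14_Weights, vit_h_14

    print(list(ViT_H_14_Weights))  # [] -- no checkpoints registered yet
    model = vit_h_14(weights=None)  # random initialization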
