Skip to content

Commit cf9ee41

Browse files
Vincent Moensfacebook-github-bot
Vincent Moens
authored andcommitted
[fbsync] [ViT] Graduate ViT from prototype (#5173)
Summary: * graduate vit from prototype * nit * add vit to docs and hubconf * ufmt * re-correct ufmt * again * fix linter Reviewed By: NicolasHug Differential Revision: D33618174 fbshipit-source-id: 3a1c6d0915d59069b27ff96a982a337ba9d7690a
1 parent d2ae51d commit cf9ee41

File tree

5 files changed

+464
-332
lines changed

5 files changed

+464
-332
lines changed

docs/source/models.rst

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ architectures for image classification:
4040
- `MNASNet`_
4141
- `EfficientNet`_
4242
- `RegNet`_
43+
- `VisionTransformer`_
4344

4445
You can construct a model with random weights by calling its constructor:
4546

@@ -82,6 +83,10 @@ You can construct a model with random weights by calling its constructor:
8283
regnet_x_8gf = models.regnet_x_8gf()
8384
regnet_x_16gf = models.regnet_x_16gf()
8485
regnet_x_32gf = models.regnet_x_32gf()
86+
vit_b_16 = models.vit_b_16()
87+
vit_b_32 = models.vit_b_32()
88+
vit_l_16 = models.vit_l_16()
89+
vit_l_32 = models.vit_l_32()
8590
8691
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
8792
These can be constructed by passing ``pretrained=True``:
@@ -125,6 +130,10 @@ These can be constructed by passing ``pretrained=True``:
125130
regnet_x_8gf = models.regnet_x_8gf(pretrained=True)
126131
regnet_x_16gf = models.regnet_x_16gf(pretrainedTrue)
127132
regnet_x_32gf = models.regnet_x_32gf(pretrained=True)
133+
vit_b_16 = models.vit_b_16(pretrained=True)
134+
vit_b_32 = models.vit_b_32(pretrained=True)
135+
vit_l_16 = models.vit_l_16(pretrained=True)
136+
vit_l_32 = models.vit_l_32(pretrained=True)
128137
129138
Instancing a pre-trained model will download its weights to a cache directory.
130139
This directory can be set using the `TORCH_HOME` environment variable. See
@@ -233,6 +242,10 @@ regnet_y_3_2gf 78.948 94.576
233242
regnet_y_8gf 80.032 95.048
234243
regnet_y_16gf 80.424 95.240
235244
regnet_y_32gf 80.878 95.340
245+
vit_b_16 81.072 95.318
246+
vit_b_32 75.912 92.466
247+
vit_l_16 79.662 94.638
248+
vit_l_32 76.972 93.070
236249
================================ ============= =============
237250

238251

@@ -250,6 +263,7 @@ regnet_y_32gf 80.878 95.340
250263
.. _MNASNet: https://arxiv.org/abs/1807.11626
251264
.. _EfficientNet: https://arxiv.org/abs/1905.11946
252265
.. _RegNet: https://arxiv.org/abs/2003.13678
266+
.. _VisionTransformer: https://arxiv.org/abs/2010.11929
253267

254268
.. currentmodule:: torchvision.models
255269

@@ -433,6 +447,18 @@ RegNet
433447
regnet_x_16gf
434448
regnet_x_32gf
435449

450+
VisionTransformer
451+
-----------------
452+
453+
.. autosummary::
454+
:toctree: generated/
455+
:template: function.rst
456+
457+
vit_b_16
458+
vit_b_32
459+
vit_l_16
460+
vit_l_32
461+
436462
Quantized Models
437463
----------------
438464

hubconf.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Optional list of dependencies required by the package
22
dependencies = ["torch"]
33

4-
# classification
54
from torchvision.models.alexnet import alexnet
65
from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
76
from torchvision.models.efficientnet import (
@@ -47,8 +46,6 @@
4746
wide_resnet50_2,
4847
wide_resnet101_2,
4948
)
50-
51-
# segmentation
5249
from torchvision.models.segmentation import (
5350
fcn_resnet50,
5451
fcn_resnet101,
@@ -60,3 +57,9 @@
6057
from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
6158
from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
6259
from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
60+
from torchvision.models.vision_transformer import (
61+
vit_b_16,
62+
vit_b_32,
63+
vit_l_16,
64+
vit_l_32,
65+
)

torchvision/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .shufflenetv2 import *
1111
from .efficientnet import *
1212
from .regnet import *
13+
from .vision_transformer import *
1314
from . import detection
1415
from . import feature_extraction
1516
from . import optical_flow

0 commit comments

Comments
 (0)