Commit e288f6c

xiaohu2015, datumbox, and jdsgomes authored
Adding Swin Transformer architecture (#5491)
* add swin transformer
* Update swin_transformer.py
* Update swin_transformer.py
* fix lint
* fix lint
* refactor code
* add swin_transformer
* Update swin_transformer.py
* fix bug
* refactor code
* fix lint
* update init_weights
* move shift_window into attention
* refactor code
* fix bug
* Update swin_transformer.py
* Update swin_transformer.py
* fix lint
* add patch_merge
* fix bug
* Update swin_transformer.py
* Update swin_transformer.py
* Update swin_transformer.py
* refactor code
* Update swin_transformer.py
* refactor code
* fix lint
* refactor code
* add swin_tiny
* add swin_tiny.pkl
* fix lint
* Delete ModelTester.test_swin_tiny_expect.pkl
* add swin_tiny
* add
* add Optional to bias
* update init weights
* update init_weights and add no weight decay
* add no weight decay
* add set_weight_decay
* add set_weight_decay
* fix lint
* fix lint
* add lr_cos_min
* add other swin models
* Update torchvision/models/swin_transformer.py (Co-authored-by: Vasilis Vryniotis <[email protected]>)
* refactor doc
* Update utils.py
* Update train.py
* Update train.py
* Update swin_transformer.py
* update model builder
* fix lint
* add
* Update torchvision/models/swin_transformer.py (Co-authored-by: Vasilis Vryniotis <[email protected]>)
* Update torchvision/models/swin_transformer.py (Co-authored-by: Vasilis Vryniotis <[email protected]>)
* update other model
* simplify the model name just like ViT
* add lr_cos_min
* fix lint
* fix lint
* Update swin_transformer.py
* Update swin_transformer.py
* Update swin_transformer.py
* Delete ModelTester.test_swin_tiny_expect.pkl
* add swin_t
* refactor code
* Update train.py
* add swin_s
* ignore a error of mypy
* Update swin_transformer.py
* fix lint
* add swin_b
* add swin_l
* refactor code
* Update train.py
* move relative_position_bias to __init__
* fix formatting
* Revert "fix formatting". This reverts commit 41faba2.
* Revert "move relative_position_bias to __init__". This reverts commit f061544.
* refactor code
* Remove deprecated meta-data from `_COMMON_META`
* fix linter
* add pretrained weights for swin_t
* fix format
* apply ufmt
* add documentation
* update references README
* adding new style docs
* update pre-trained weights values
* remove other variants
* fix typo
* Remove expect for the variants not yet supported

Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Joao Gomes <[email protected]>
1 parent bb1ab47 commit e288f6c

8 files changed (+517, -2 lines)


docs/source/models.rst

Lines changed: 13 additions & 0 deletions
```diff
@@ -42,6 +42,7 @@ architectures for image classification:
 - `RegNet`_
 - `VisionTransformer`_
 - `ConvNeXt`_
+- `SwinTransformer`_
 
 You can construct a model with random weights by calling its constructor:
 
@@ -97,6 +98,7 @@ You can construct a model with random weights by calling its constructor:
     convnext_small = models.convnext_small()
     convnext_base = models.convnext_base()
     convnext_large = models.convnext_large()
+    swin_t = models.swin_t()
 
 We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
 
@@ -219,6 +221,7 @@ convnext_tiny 82.520 96.146
 convnext_small                   83.616        96.650
 convnext_base                    84.062        96.870
 convnext_large                   84.414        96.976
+swin_t                           81.358        95.526
 ================================ ============= =============
 
 
@@ -238,6 +241,7 @@ convnext_large 84.414 96.976
 .. _RegNet: https://arxiv.org/abs/2003.13678
 .. _VisionTransformer: https://arxiv.org/abs/2010.11929
 .. _ConvNeXt: https://arxiv.org/abs/2201.03545
+.. _SwinTransformer: https://arxiv.org/abs/2103.14030
 
 .. currentmodule:: torchvision.models
 
@@ -450,6 +454,15 @@ ConvNeXt
     convnext_base
     convnext_large
 
+SwinTransformer
+---------------
+
+.. autosummary::
+    :toctree: generated/
+    :template: function.rst
+
+    swin_t
+
 Quantized Models
 ----------------
 
```
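As the updated `models.rst` shows, the new builder behaves like the other classification constructors. A minimal sketch of random-weight construction (assuming a torchvision build that includes this commit):

```python
import torch
from torchvision import models

# Construct Swin-T with random weights, exactly as documented above.
model = models.swin_t()
model.eval()

# Classification forward pass on a standard 224x224 ImageNet-sized input.
with torch.no_grad():
    logits = model(torch.rand(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```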

docs/source/models/swin_transformer.rst

Lines changed: 25 additions & 0 deletions
```diff
@@ -0,0 +1,25 @@
+SwinTransformer
+===============
+
+.. currentmodule:: torchvision.models
+
+The SwinTransformer model is based on the `Swin Transformer: Hierarchical Vision
+Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`__
+paper.
+
+
+Model builders
+--------------
+
+The following model builders can be used to instantiate a SwinTransformer model.
+`swin_t` can be instantiated with pre-trained weights and all others without.
+All the model builders internally rely on the ``torchvision.models.swin_transformer.SwinTransformer``
+base class. Please refer to the `source code
+<https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_ for
+more details about this class.
+
+.. autosummary::
+    :toctree: generated/
+    :template: function.rst
+
+    swin_t
```
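Since the doc page states that `swin_t` can be instantiated with pre-trained weights, a short sketch of that path may help; it assumes the new weights API and the `Swin_T_Weights.IMAGENET1K_V1` enum entry shipped alongside this commit:

```python
from torchvision import models
from torchvision.models import Swin_T_Weights  # new-style weights enum (assumed available)

# Build swin_t with its ImageNet-1K weights and grab the matching preprocessing.
weights = Swin_T_Weights.IMAGENET1K_V1
model = models.swin_t(weights=weights).eval()
preprocess = weights.transforms()  # resize/center-crop/normalize preset used for eval
```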

docs/source/models_new.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -46,6 +46,7 @@ weights:
     models/resnet
     models/resnext
     models/squeezenet
+    models/swin_transformer
     models/vgg
     models/vision_transformer
 
```

references/classification/README.md

Lines changed: 12 additions & 0 deletions
````diff
@@ -224,6 +224,18 @@ Note that the above command corresponds to training on a single node with 8 GPUs
 For generating the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 16 GPUs),
 and `--batch_size 64`.
 
+
+### SwinTransformer
+```
+torchrun --nproc_per_node=8 train.py\
+--model swin_t --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0\
+--bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear\
+--lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8\
+--clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ra
+```
+Note that `--val-resize-size` was optimized in a post-training step, see their `Weights` entry for the exact value.
+
+
 ## Mixed precision training
 Automatic Mixed Precision (AMP) training on GPU for Pytorch can be enabled with the [torch.cuda.amp](https://pytorch.org/docs/stable/amp.html?highlight=amp#module-torch.cuda.amp).
````

references/classification/train.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -233,7 +233,7 @@ def main(args):
     if args.bias_weight_decay is not None:
         custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
     if args.transformer_embedding_decay is not None:
-        for key in ["class_token", "position_embedding", "relative_position_bias"]:
+        for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
             custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
     parameters = utils.set_weight_decay(
         model,
```
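The key rename matters because Swin stores its relative position bias in a parameter named `relative_position_bias_table`; the old key never matched it, so the table silently received weight decay. Below is a simplified stand-in for the reference `utils.set_weight_decay` (the helper `build_param_groups` is hypothetical, does substring matching only, and skips the norm-layer handling the real utility performs):

```python
import torch

def build_param_groups(model, base_wd, custom_keys_weight_decay):
    # Route each parameter to the weight decay of the first key that appears
    # in its name; everything else keeps the base weight decay.
    groups = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        wd = next((w for key, w in custom_keys_weight_decay if key in name), base_wd)
        groups.setdefault(wd, []).append(param)
    return [{"params": ps, "weight_decay": wd} for wd, ps in groups.items()]

# Stand-in model; the Swin recipe zeroes decay for biases and the bias table.
model = torch.nn.TransformerEncoderLayer(d_model=96, nhead=3)
param_groups = build_param_groups(
    model, base_wd=0.05,
    custom_keys_weight_decay=[("bias", 0.0), ("relative_position_bias_table", 0.0)],
)
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
```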
```diff
@@ -267,7 +267,7 @@ def main(args):
         main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
     elif args.lr_scheduler == "cosineannealinglr":
         main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
-            optimizer, T_max=args.epochs - args.lr_warmup_epochs
+            optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
         )
     elif args.lr_scheduler == "exponentiallr":
         main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
@@ -424,6 +424,7 @@ def get_args_parser(add_help=True):
     parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
     parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
     parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
+    parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
     parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
     parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
     parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
```
New binary file (939 Bytes) not shown.

torchvision/models/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -12,6 +12,7 @@
 from .squeezenet import *
 from .vgg import *
 from .vision_transformer import *
+from .swin_transformer import *
 from . import detection
 from . import optical_flow
 from . import quantization
```
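This star-import is what surfaces the new builders at the package root. A quick sanity check (assuming the installed torchvision includes this commit):

```python
from torchvision.models import swin_t                             # re-exported builder
from torchvision.models.swin_transformer import SwinTransformer   # base class

model = swin_t()
assert isinstance(model, SwinTransformer)
```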
