diff --git a/references/classification/README.md b/references/classification/README.md
index a73fde3679f..48b20a30242 100644
--- a/references/classification/README.md
+++ b/references/classification/README.md
@@ -143,6 +143,60 @@ torchrun --nproc_per_node=8 train.py\
 ```
 Here `$MODEL` is one of `regnet_x_32gf`, `regnet_y_16gf` and `regnet_y_32gf`.
 
+### Vision Transformer
+
+#### vit_b_16
+```
+torchrun --nproc_per_node=8 train.py\
+    --model vit_b_16 --epochs 300 --batch-size 512 --opt adamw --lr 0.003 --wd 0.3\
+    --lr-scheduler cosineannealinglr --lr-warmup-method linear --lr-warmup-epochs 30\
+    --lr-warmup-decay 0.033 --amp --label-smoothing 0.11 --mixup-alpha 0.2 --auto-augment ra\
+    --clip-grad-norm 1 --ra-sampler --cutmix-alpha 1.0 --model-ema
+```
+
+Note that the above command corresponds to training on a single node with 8 GPUs.
+For generating the pre-trained weights, we trained with 8 nodes, each with 8 GPUs (for a total of 64 GPUs),
+and `--batch-size 64`.
+
+#### vit_b_32
+```
+torchrun --nproc_per_node=8 train.py\
+    --model vit_b_32 --epochs 300 --batch-size 512 --opt adamw --lr 0.003 --wd 0.3\
+    --lr-scheduler cosineannealinglr --lr-warmup-method linear --lr-warmup-epochs 30\
+    --lr-warmup-decay 0.033 --amp --label-smoothing 0.11 --mixup-alpha 0.2 --auto-augment imagenet\
+    --clip-grad-norm 1 --ra-sampler --cutmix-alpha 1.0 --model-ema
+```
+
+Note that the above command corresponds to training on a single node with 8 GPUs.
+For generating the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 16 GPUs),
+and `--batch-size 256`.
+
+#### vit_l_16
+```
+torchrun --nproc_per_node=8 train.py\
+    --model vit_l_16 --epochs 600 --batch-size 128 --lr 0.5 --lr-scheduler cosineannealinglr\
+    --lr-warmup-method linear --lr-warmup-epochs 5 --label-smoothing 0.1 --mixup-alpha 0.2\
+    --auto-augment ta_wide --random-erase 0.1 --weight-decay 0.00002 --norm-weight-decay 0.0\
+    --clip-grad-norm 1 --ra-sampler --cutmix-alpha 1.0 --model-ema --val-resize-size 232
+```
+
+Note that the above command corresponds to training on a single node with 8 GPUs.
+For generating the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 16 GPUs),
+and `--batch-size 64`.
+
+#### vit_l_32
+```
+torchrun --nproc_per_node=8 train.py\
+    --model vit_l_32 --epochs 300 --batch-size 512 --opt adamw --lr 0.003 --wd 0.3\
+    --lr-scheduler cosineannealinglr --lr-warmup-method linear --lr-warmup-epochs 30\
+    --lr-warmup-decay 0.033 --amp --label-smoothing 0.11 --mixup-alpha 0.2 --auto-augment ra\
+    --clip-grad-norm 1 --ra-sampler --cutmix-alpha 1.0 --model-ema
+```
+
+Note that the above command corresponds to training on a single node with 8 GPUs.
+For generating the pre-trained weights, we trained with 8 nodes, each with 8 GPUs (for a total of 64 GPUs),
+and `--batch-size 64`.
+
 ## Mixed precision training
 Automatic Mixed Precision (AMP) training on GPU for Pytorch can be enabled with the [torch.cuda.amp](https://pytorch.org/docs/stable/amp.html?highlight=amp#module-torch.cuda.amp).
diff --git a/torchvision/prototype/models/vision_transformer.py b/torchvision/prototype/models/vision_transformer.py
index 7dd152c4a0c..a3b0ec8e7e9 100644
--- a/torchvision/prototype/models/vision_transformer.py
+++ b/torchvision/prototype/models/vision_transformer.py
@@ -10,12 +10,14 @@
 import torch
 import torch.nn as nn
 from torch import Tensor
+from torchvision.prototype.transforms import ImageNetEval
+from torchvision.transforms.functional import InterpolationMode
 
 from ...utils import _log_api_usage_once
-from ._api import WeightsEnum
+from ._api import WeightsEnum, Weights
+from ._meta import _IMAGENET_CATEGORIES
 from ._utils import handle_legacy_interface
-
 
 __all__ = [
     "VisionTransformer",
     "ViT_B_16_Weights",
@@ -233,24 +235,70 @@ def forward(self, x: torch.Tensor):
         return x
 
 
+_COMMON_META = {
+    "categories": _IMAGENET_CATEGORIES,
+    "interpolation": InterpolationMode.BILINEAR,
+}
+
+
 class ViT_B_16_Weights(WeightsEnum):
-    # If a default model is added here the corresponding changes need to be done in vit_b_16
-    pass
+    ImageNet1K_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_b_16-c867db91.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "size": (224, 224),
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_16",
+            "acc@1": 81.072,
+            "acc@5": 95.318,
+        },
+    )
+    default = ImageNet1K_V1
 
 
 class ViT_B_32_Weights(WeightsEnum):
-    # If a default model is added here the corresponding changes need to be done in vit_b_32
-    pass
+    ImageNet1K_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "size": (224, 224),
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_32",
+            "acc@1": 75.912,
+            "acc@5": 92.466,
+        },
+    )
+    default = ImageNet1K_V1
 
 
 class ViT_L_16_Weights(WeightsEnum):
-    # If a default model is added here the corresponding changes need to be done in vit_l_16
-    pass
+    ImageNet1K_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
+        transforms=partial(ImageNetEval, crop_size=224, resize_size=242),
+        meta={
+            **_COMMON_META,
+            "size": (224, 224),
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_16",
+            "acc@1": 79.662,
+            "acc@5": 94.638,
+        },
+    )
+    default = ImageNet1K_V1
 
 
 class ViT_L_32_Weights(WeightsEnum):
-    # If a default model is added here the corresponding changes need to be done in vit_l_32
-    pass
+    ImageNet1K_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_l_32-c7638314.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "size": (224, 224),
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_32",
+            "acc@1": 76.972,
+            "acc@5": 93.07,
+        },
+    )
+    default = ImageNet1K_V1
 
 
 def _vision_transformer(
@@ -281,7 +329,7 @@ def _vision_transformer(
     return model
 
 
-@handle_legacy_interface(weights=("pretrained", None))
+@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.ImageNet1K_V1))
 def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
     """
     Constructs a vit_b_16 architecture from
@@ -306,7 +354,7 @@ def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = Tru
     )
 
 
-@handle_legacy_interface(weights=("pretrained", None))
+@handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.ImageNet1K_V1))
 def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
     """
     Constructs a vit_b_32 architecture from
@@ -331,7 +379,7 @@ def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = Tru
     )
 
 
-@handle_legacy_interface(weights=("pretrained", None))
+@handle_legacy_interface(weights=("pretrained", ViT_L_16_Weights.ImageNet1K_V1))
 def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
     """
     Constructs a vit_l_16 architecture from
@@ -356,7 +404,7 @@ def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = Tru
     )
 
 
-@handle_legacy_interface(weights=("pretrained", None))
+@handle_legacy_interface(weights=("pretrained", ViT_L_32_Weights.ImageNet1K_V1))
 def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
     """
     Constructs a vit_l_32 architecture from