diff --git a/references/classification/README.md b/references/classification/README.md index bae563c31c5..006e9c398b1 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -31,6 +31,17 @@ Here `$MODEL` is one of `alexnet`, `vgg11`, `vgg13`, `vgg16` or `vgg19`. Note that `vgg11_bn`, `vgg13_bn`, `vgg16_bn`, and `vgg19_bn` include batch normalization and thus are trained with the default parameters. +### Inception V3 + +The weights of the Inception V3 model are ported from the original paper rather than trained from scratch. + +Since it expects tensors with a size of N x 3 x 299 x 299, to validate the model use the following command: + +``` +torchrun --nproc_per_node=8 train.py --model inception_v3 + --val-resize-size 342 --val-crop-size 299 --train-crop-size 299 --test-only --pretrained +``` + ### ResNext-50 32x4d ``` torchrun --nproc_per_node=8 train.py\ @@ -79,6 +90,25 @@ The weights of the B0-B4 variants are ported from Ross Wightman's [timm repo](ht The weights of the B5-B7 variants are ported from Luke Melas' [EfficientNet-PyTorch repo](https://github.com/lukemelas/EfficientNet-PyTorch/blob/1039e009545d9329ea026c9f7541341439712b96/efficientnet_pytorch/utils.py#L562-L564). +All models were trained using Bicubic interpolation and each have custom crop and resize sizes. To validate the models use the following commands: +``` +torchrun --nproc_per_node=8 train.py --model efficientnet_b0 --interpolation bicubic\ + --val-resize-size 256 --val-crop-size 224 --train-crop-size 224 --test-only --pretrained +torchrun --nproc_per_node=8 train.py --model efficientnet_b1 --interpolation bicubic\ + --val-resize-size 256 --val-crop-size 240 --train-crop-size 240 --test-only --pretrained +torchrun --nproc_per_node=8 train.py --model efficientnet_b2 --interpolation bicubic\ + --val-resize-size 288 --val-crop-size 288 --train-crop-size 288 --test-only --pretrained +torchrun --nproc_per_node=8 train.py --model efficientnet_b3 --interpolation bicubic\ + --val-resize-size 320 --val-crop-size 300 --train-crop-size 300 --test-only --pretrained +torchrun --nproc_per_node=8 train.py --model efficientnet_b4 --interpolation bicubic\ + --val-resize-size 384 --val-crop-size 380 --train-crop-size 380 --test-only --pretrained +torchrun --nproc_per_node=8 train.py --model efficientnet_b5 --interpolation bicubic\ + --val-resize-size 456 --val-crop-size 456 --train-crop-size 456 --test-only --pretrained +torchrun --nproc_per_node=8 train.py --model efficientnet_b6 --interpolation bicubic\ + --val-resize-size 528 --val-crop-size 528 --train-crop-size 528 --test-only --pretrained +torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --interpolation bicubic\ + --val-resize-size 600 --val-crop-size 600 --train-crop-size 600 --test-only --pretrained +``` ### RegNet @@ -181,3 +211,8 @@ For post training quant, device is set to CPU. For training, the device is set t ``` python train_quantization.py --device='cpu' --test-only --backend='' --model='' ``` + +For inception_v3 you need to pass the following extra parameters: +``` +--val-resize-size 342 --val-crop-size 299 --train-crop-size 299 +``` diff --git a/references/classification/presets.py b/references/classification/presets.py index 27ce486207d..6e1000174ab 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -9,21 +9,22 @@ def __init__( crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), + interpolation=InterpolationMode.BILINEAR, hflip_prob=0.5, auto_augment_policy=None, random_erase_prob=0.0, ): - trans = [transforms.RandomResizedCrop(crop_size)] + trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] if hflip_prob > 0: trans.append(transforms.RandomHorizontalFlip(hflip_prob)) if auto_augment_policy is not None: if auto_augment_policy == "ra": - trans.append(autoaugment.RandAugment()) + trans.append(autoaugment.RandAugment(interpolation=interpolation)) elif auto_augment_policy == "ta_wide": - trans.append(autoaugment.TrivialAugmentWide()) + trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation)) else: aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) - trans.append(autoaugment.AutoAugment(policy=aa_policy)) + trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation)) trans.extend( [ transforms.PILToTensor(), diff --git a/references/classification/train.py b/references/classification/train.py index 7addad30350..bae6adea63a 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -14,22 +14,20 @@ from torchvision.transforms.functional import InterpolationMode -def train_one_epoch( - model, criterion, optimizer, data_loader, device, epoch, print_freq, amp=False, model_ema=None, scaler=None -): +def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None): model.train() metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}")) metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}")) header = "Epoch: [{}]".format(epoch) - for image, target in metric_logger.log_every(data_loader, print_freq, header): + for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)): start_time = time.time() image, target = image.to(device), target.to(device) output = model(image) optimizer.zero_grad() - if amp: + if args.amp: with torch.cuda.amp.autocast(): loss = criterion(output, target) scaler.scale(loss).backward() @@ -40,6 +38,12 @@ def train_one_epoch( loss.backward() optimizer.step() + if model_ema and i % args.model_ema_steps == 0: + model_ema.update_parameters(model) + if epoch < args.lr_warmup_epochs: + # Reset ema buffer to keep copying weights during warmup period + model_ema.n_averaged.fill_(0) + acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) batch_size = image.shape[0] metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) @@ -47,9 +51,6 @@ def train_one_epoch( metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time)) - if model_ema: - model_ema.update_parameters(model) - def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""): model.eval() @@ -106,24 +107,8 @@ def _get_cache_path(filepath): def load_data(traindir, valdir, args): # Data loading code print("Loading data") - resize_size, crop_size = 256, 224 - interpolation = InterpolationMode.BILINEAR - if args.model == "inception_v3": - resize_size, crop_size = 342, 299 - elif args.model.startswith("efficientnet_"): - sizes = { - "b0": (256, 224), - "b1": (256, 240), - "b2": (288, 288), - "b3": (320, 300), - "b4": (384, 380), - "b5": (456, 456), - "b6": (528, 528), - "b7": (600, 600), - } - e_type = args.model.replace("efficientnet_", "") - resize_size, crop_size = sizes[e_type] - interpolation = InterpolationMode.BICUBIC + val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size + interpolation = InterpolationMode(args.interpolation) print("Loading training data") st = time.time() @@ -138,7 +123,10 @@ def load_data(traindir, valdir, args): dataset = torchvision.datasets.ImageFolder( traindir, presets.ClassificationPresetTrain( - crop_size=crop_size, auto_augment_policy=auto_augment_policy, random_erase_prob=random_erase_prob + crop_size=train_crop_size, + interpolation=interpolation, + auto_augment_policy=auto_augment_policy, + random_erase_prob=random_erase_prob, ), ) if args.cache_dataset: @@ -156,7 +144,9 @@ def load_data(traindir, valdir, args): else: dataset_test = torchvision.datasets.ImageFolder( valdir, - presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size, interpolation=interpolation), + presets.ClassificationPresetEval( + crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation + ), ) if args.cache_dataset: print("Saving dataset_test to {}".format(cache_path)) @@ -224,10 +214,17 @@ def main(args): criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) + if args.norm_weight_decay is None: + parameters = model.parameters() + else: + param_groups = torchvision.ops._utils.split_normalization_params(model) + wd_groups = [args.norm_weight_decay, args.weight_decay] + parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p] + opt_name = args.opt.lower() if opt_name.startswith("sgd"): optimizer = torch.optim.SGD( - model.parameters(), + parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, @@ -235,15 +232,12 @@ def main(args): ) elif opt_name == "rmsprop": optimizer = torch.optim.RMSprop( - model.parameters(), - lr=args.lr, - momentum=args.momentum, - weight_decay=args.weight_decay, - eps=0.0316, - alpha=0.9, + parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9 ) + elif opt_name == "adamw": + optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay) else: - raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt)) + raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.") scaler = torch.cuda.amp.GradScaler() if args.amp else None @@ -288,13 +282,23 @@ def main(args): model_ema = None if args.model_ema: - model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay) + # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at: + # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123 + # + # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps) + # We consider constant = Dataset_size for a given dataset/setup and ommit it. Thus: + # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs + adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs + alpha = 1.0 - args.model_ema_decay + alpha = min(1.0, alpha * adjust) + model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha) if args.resume: checkpoint = torch.load(args.resume, map_location="cpu") model_without_ddp.load_state_dict(checkpoint["model"]) - optimizer.load_state_dict(checkpoint["optimizer"]) - lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + if not args.test_only: + optimizer.load_state_dict(checkpoint["optimizer"]) + lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) args.start_epoch = checkpoint["epoch"] + 1 if model_ema: model_ema.load_state_dict(checkpoint["model_ema"]) @@ -303,8 +307,10 @@ def main(args): # We disable the cudnn benchmarking because it can noticeably affect the accuracy torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True - - evaluate(model, criterion, data_loader_test, device=device) + if model_ema: + evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA") + else: + evaluate(model, criterion, data_loader_test, device=device) return print("Start training") @@ -312,9 +318,7 @@ def main(args): for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) - train_one_epoch( - model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.amp, model_ema, scaler - ) + train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler) lr_scheduler.step() evaluate(model, criterion, data_loader_test, device=device) if model_ema: @@ -362,6 +366,12 @@ def get_args_parser(add_help=True): help="weight decay (default: 1e-4)", dest="weight_decay", ) + parser.add_argument( + "--norm-weight-decay", + default=None, + type=float, + help="weight decay for Normalization layers (default: None, same value as --wd)", + ) parser.add_argument( "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing" ) @@ -415,15 +425,33 @@ def get_args_parser(add_help=True): parser.add_argument( "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters" ) + parser.add_argument( + "--model-ema-steps", + type=int, + default=32, + help="the number of iterations that controls how often to update the EMA model (default: 32)", + ) parser.add_argument( "--model-ema-decay", type=float, - default=0.9, - help="decay factor for Exponential Moving Average of model parameters(default: 0.9)", + default=0.99998, + help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)", ) parser.add_argument( "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only." ) + parser.add_argument( + "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)" + ) + parser.add_argument( + "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)" + ) + parser.add_argument( + "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)" + ) + parser.add_argument( + "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)" + ) return parser diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py index ae4e81b0133..f384be76a62 100644 --- a/references/classification/train_quantization.py +++ b/references/classification/train_quantization.py @@ -236,6 +236,19 @@ def get_args_parser(add_help=True): parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training") + parser.add_argument( + "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)" + ) + parser.add_argument( + "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)" + ) + parser.add_argument( + "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)" + ) + parser.add_argument( + "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)" + ) + return parser diff --git a/references/classification/utils.py b/references/classification/utils.py index 9e043582d13..1a4adc7f60f 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -380,6 +380,9 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T # Load the weights to the model to validate that everything works # and remove unnecessary weights (such as auxiliaries, etc) + if checkpoint_key == "model_ema": + del checkpoint[checkpoint_key]["n_averaged"] + torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.") model.load_state_dict(checkpoint[checkpoint_key], strict=strict) tmp_path = os.path.join(output_dir, str(model.__hash__())) diff --git a/test/test_ops.py b/test/test_ops.py index 64329936b72..892496dffca 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -9,10 +9,10 @@ import torch from common_utils import needs_cuda, cpu_and_gpu, assert_equal from PIL import Image -from torch import Tensor +from torch import nn, Tensor from torch.autograd import gradcheck from torch.nn.modules.utils import _pair -from torchvision import ops +from torchvision import models, ops class RoIOpTester(ABC): @@ -1176,5 +1176,15 @@ def test_stochastic_depth(self, mode, p): assert p_value > 0.0001 +class TestUtils: + @pytest.mark.parametrize("norm_layer", [None, nn.BatchNorm2d, nn.LayerNorm]) + def test_split_normalization_params(self, norm_layer): + model = models.mobilenet_v3_large(norm_layer=norm_layer) + params = ops._utils.split_normalization_params(model, None if norm_layer is None else [norm_layer]) + + assert len(params[0]) == 92 + assert len(params[1]) == 82 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py index 86dfce46509..3a07c747f58 100644 --- a/torchvision/ops/_utils.py +++ b/torchvision/ops/_utils.py @@ -1,7 +1,7 @@ -from typing import List, Union +from typing import List, Optional, Tuple, Union import torch -from torch import Tensor +from torch import nn, Tensor def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor: @@ -36,3 +36,28 @@ def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]): else: assert False, "boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]" return + + +def split_normalization_params( + model: nn.Module, norm_classes: Optional[List[type]] = None +) -> Tuple[List[Tensor], List[Tensor]]: + # Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501 + if not norm_classes: + norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm] + + for t in norm_classes: + if not issubclass(t, nn.Module): + raise ValueError(f"Class {t} is not a subclass of nn.Module.") + + classes = tuple(norm_classes) + + norm_params = [] + other_params = [] + for module in model.modules(): + if next(module.children(), None): + other_params.extend(p for p in module.parameters(recurse=False) if p.requires_grad) + elif isinstance(module, classes): + norm_params.extend(p for p in module.parameters() if p.requires_grad) + else: + other_params.extend(p for p in module.parameters() if p.requires_grad) + return norm_params, other_params