From fb815ec7b158c8de6d79aa1e2b6cf6ad849de460 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Tue, 14 Sep 2021 14:18:30 +0100
Subject: [PATCH 1/7] Warmup on Classification references.

---
 references/classification/train.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/references/classification/train.py b/references/classification/train.py
index a3e4c9ad8e9..640de44a5cc 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -195,7 +195,24 @@ def main(args):
                                           opt_level=args.apex_opt_level
                                           )
 
-    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
+    args.lr_scheduler = args.lr_scheduler.lower()
+    if args.lr_scheduler == 'steplr':
+        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
+    elif args.lr_scheduler == 'cosineannealinglr':
+        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
+    else:
+        raise RuntimeError("Invalid lr scheduler '{}'. Only StepLR and CosineAnnealingLR "
+                           "are supported.".format(args.lr_scheduler))
+
+    if args.lr_warmup_epochs > 0:
+        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
+            optimizer,
+            schedulers=[torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
+                                                            total_iters=args.lr_warmup_epochs), main_lr_scheduler],
+            milestones=[args.lr_warmup_epochs]
+        )
+    else:
+        lr_scheduler = main_lr_scheduler
 
     model_without_ddp = model
     if args.distributed:
@@ -272,6 +289,9 @@ def get_args_parser(add_help=True):
     parser.add_argument('--label-smoothing', default=0.0, type=float,
                         help='label smoothing (default: 0.0)',
                         dest='label_smoothing')
+    parser.add_argument('--lr-scheduler', default="steplr", help='the lr scheduler (default: steplr)')
+    parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
+    parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr')
     parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')

From 3f9d8fa7e9aa09816a9ca69b7cab78a01accc11f Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Tue, 14 Sep 2021 14:45:49 +0100
Subject: [PATCH 2/7] Adjust epochs for cosine.

---
 references/classification/train.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/references/classification/train.py b/references/classification/train.py
index 640de44a5cc..4afc9bc9010 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -199,7 +199,8 @@ def main(args):
     if args.lr_scheduler == 'steplr':
         main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
     elif args.lr_scheduler == 'cosineannealinglr':
-        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
+        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
+                                                                       T_max=args.epochs - args.lr_warmup_epochs)
     else:
         raise RuntimeError("Invalid lr scheduler '{}'. Only StepLR and CosineAnnealingLR "
                            "are supported.".format(args.lr_scheduler))
From 8e940eb2bfdaf98ec46f04a68667966c14e4511b Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 16 Sep 2021 10:44:38 +0100
Subject: [PATCH 3/7] Warmup on Segmentation references.

---
 references/segmentation/train.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/references/segmentation/train.py b/references/segmentation/train.py
index fb6c7eeee15..66b0a54101f 100644
--- a/references/segmentation/train.py
+++ b/references/segmentation/train.py
@@ -133,9 +133,21 @@ def main(args):
         params_to_optimize,
         lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
 
-    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
+    iters_per_epoch = len(data_loader)
+    main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
         optimizer,
-        lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
+        lambda x: (1 - x / (iters_per_epoch * (args.epochs - args.lr_warmup_epochs))) ** 0.9)
+
+    if args.lr_warmup_epochs > 0:
+        warmup_iters = iters_per_epoch * args.lr_warmup_epochs
+        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
+            optimizer,
+            schedulers=[torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay,
+                                                          total_iters=warmup_iters), main_lr_scheduler],
+            milestones=[warmup_iters]
+        )
+    else:
+        lr_scheduler = main_lr_scheduler
 
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
@@ -197,6 +209,8 @@ def get_args_parser(add_help=True):
     parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                         metavar='W', help='weight decay (default: 1e-4)',
                         dest='weight_decay')
+    parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
+    parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')

From 5e0cbcf62f6d3dc5485e15c321ad396013bb156d Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 16 Sep 2021 16:10:27 +0100
Subject: [PATCH 4/7] Warmup on Video classification references.

---
 references/video_classification/scheduler.py | 47 --------------------
 references/video_classification/train.py     | 34 ++++++++++----
 2 files changed, 26 insertions(+), 55 deletions(-)
 delete mode 100644 references/video_classification/scheduler.py

diff --git a/references/video_classification/scheduler.py b/references/video_classification/scheduler.py
deleted file mode 100644
index f0f862d41ad..00000000000
--- a/references/video_classification/scheduler.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import torch
-from bisect import bisect_right
-
-
-class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
-    def __init__(
-        self,
-        optimizer,
-        milestones,
-        gamma=0.1,
-        warmup_factor=1.0 / 3,
-        warmup_iters=5,
-        warmup_method="linear",
-        last_epoch=-1,
-    ):
-        if not milestones == sorted(milestones):
-            raise ValueError(
-                "Milestones should be a list of" " increasing integers. Got {}",
-                milestones,
-            )
Got {}", - milestones, - ) - - if warmup_method not in ("constant", "linear"): - raise ValueError( - "Only 'constant' or 'linear' warmup_method accepted" - "got {}".format(warmup_method) - ) - self.milestones = milestones - self.gamma = gamma - self.warmup_factor = warmup_factor - self.warmup_iters = warmup_iters - self.warmup_method = warmup_method - super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) - - def get_lr(self): - warmup_factor = 1 - if self.last_epoch < self.warmup_iters: - if self.warmup_method == "constant": - warmup_factor = self.warmup_factor - elif self.warmup_method == "linear": - alpha = float(self.last_epoch) / self.warmup_iters - warmup_factor = self.warmup_factor * (1 - alpha) + alpha - return [ - base_lr * - warmup_factor * - self.gamma ** bisect_right(self.milestones, self.last_epoch) - for base_lr in self.base_lrs - ] diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 11ac2d5378d..41046ab56a1 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -12,8 +12,6 @@ import presets import utils -from scheduler import WarmupMultiStepLR - try: from apex import amp except ImportError: @@ -202,11 +200,29 @@ def main(args): # convert scheduler to be per iteration, not per epoch, for warmup that lasts # between different epochs - warmup_iters = args.lr_warmup_epochs * len(data_loader) - lr_milestones = [len(data_loader) * m for m in args.lr_milestones] - lr_scheduler = WarmupMultiStepLR( - optimizer, milestones=lr_milestones, gamma=args.lr_gamma, - warmup_iters=warmup_iters, warmup_factor=1e-5) + iters_per_epoch = len(data_loader) + lr_milestones = [iters_per_epoch * (m - args.lr_warmup_epochs) for m in args.lr_milestones] + main_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones, gamma=args.lr_gamma) + + if args.lr_warmup_epochs > 0: + warmup_iters = iters_per_epoch * args.lr_warmup_epochs + if args.lr_warmup_method == 'linear': + warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay, + total_iters=warmup_iters) + elif args.lr_warmup_method == 'constant': + warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay, + total_iters=warmup_iters) + else: + raise RuntimeError("Invalid warmup lr method '{}'. 
+
+        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
+            optimizer,
+            schedulers=[warmup_lr_scheduler, main_lr_scheduler],
+            milestones=[warmup_iters]
+        )
+    else:
+        lr_scheduler = main_lr_scheduler
 
     model_without_ddp = model
     if args.distributed:
@@ -277,7 +293,9 @@ def parse_args():
                         dest='weight_decay')
     parser.add_argument('--lr-milestones', nargs='+', default=[20, 30, 40], type=int, help='decrease lr on milestones')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
-    parser.add_argument('--lr-warmup-epochs', default=10, type=int, help='number of warmup epochs')
+    parser.add_argument('--lr-warmup-epochs', default=10, type=int, help='the number of epochs to warmup (default: 10)')
+    parser.add_argument('--lr-warmup-method', default="linear", type=str, help='the warmup method (default: linear)')
+    parser.add_argument('--lr-warmup-decay', default=0.001, type=float, help='the decay for lr')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')

From 0a64eba27e4293f407d53994914c405090d36185 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 16 Sep 2021 18:38:12 +0100
Subject: [PATCH 5/7] Adding support of both types of warmup in segmentation.

---
 references/segmentation/train.py         | 14 ++++++++++++--
 references/video_classification/train.py |  1 +
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/references/segmentation/train.py b/references/segmentation/train.py
index 66b0a54101f..476058ce0c0 100644
--- a/references/segmentation/train.py
+++ b/references/segmentation/train.py
@@ -140,10 +140,19 @@ def main(args):
 
     if args.lr_warmup_epochs > 0:
         warmup_iters = iters_per_epoch * args.lr_warmup_epochs
+        args.lr_warmup_method = args.lr_warmup_method.lower()
+        if args.lr_warmup_method == 'linear':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay,
+                                                                    total_iters=warmup_iters)
+        elif args.lr_warmup_method == 'constant':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
+                                                                      total_iters=warmup_iters)
+        else:
+            raise RuntimeError("Invalid warmup lr method '{}'. Only linear and constant "
+                               "are supported.".format(args.lr_warmup_method))
         lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
             optimizer,
-            schedulers=[torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay,
-                                                          total_iters=warmup_iters), main_lr_scheduler],
+            schedulers=[warmup_lr_scheduler, main_lr_scheduler],
             milestones=[warmup_iters]
         )
     else:
@@ -210,6 +219,7 @@ def get_args_parser(add_help=True):
                         metavar='W', help='weight decay (default: 1e-4)',
                         dest='weight_decay')
     parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
+    parser.add_argument('--lr-warmup-method', default="linear", type=str, help='the warmup method (default: linear)')
     parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
diff --git a/references/video_classification/train.py b/references/video_classification/train.py
index 41046ab56a1..353e0d6d1f7 100644
--- a/references/video_classification/train.py
+++ b/references/video_classification/train.py
@@ -206,6 +206,7 @@ def main(args):
 
     if args.lr_warmup_epochs > 0:
         warmup_iters = iters_per_epoch * args.lr_warmup_epochs
+        args.lr_warmup_method = args.lr_warmup_method.lower()
         if args.lr_warmup_method == 'linear':
             warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay,
                                                                     total_iters=warmup_iters)

From f95a13dd7edc93709519aa90a77f90dd43b7b0f7 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 16 Sep 2021 19:16:05 +0100
Subject: [PATCH 6/7] Use LinearLR in detection.

---
 references/detection/engine.py |  3 ++-
 references/detection/utils.py  | 11 -----------
 2 files changed, 2 insertions(+), 12 deletions(-)

diff --git a/references/detection/engine.py b/references/detection/engine.py
index 49992af60a9..82c23c178b1 100644
--- a/references/detection/engine.py
+++ b/references/detection/engine.py
@@ -21,7 +21,8 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
         warmup_factor = 1. / 1000
         warmup_iters = min(1000, len(data_loader) - 1)
 
-        lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)
+        lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=warmup_factor,
+                                                         total_iters=warmup_iters)
 
     for images, targets in metric_logger.log_every(data_loader, print_freq, header):
         images = list(image.to(device) for image in images)
diff --git a/references/detection/utils.py b/references/detection/utils.py
index 3c52abb2167..11fcd3060e4 100644
--- a/references/detection/utils.py
+++ b/references/detection/utils.py
@@ -207,17 +207,6 @@ def collate_fn(batch):
     return tuple(zip(*batch))
 
 
-def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
-
-    def f(x):
-        if x >= warmup_iters:
-            return 1
-        alpha = float(x) / warmup_iters
-        return warmup_factor * (1 - alpha) + alpha
-
-    return torch.optim.lr_scheduler.LambdaLR(optimizer, f)
-
-
 def mkdir(path):
     try:
         os.makedirs(path)

From ba7d1d82ad76e3c2e22334803c578974db10fe04 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 16 Sep 2021 20:17:02 +0100
Subject: [PATCH 7/7] Fix deprecation warning.
--- references/segmentation/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py index 4fe5a5ad147..fc828f4bab2 100644 --- a/references/segmentation/transforms.py +++ b/references/segmentation/transforms.py @@ -37,7 +37,7 @@ def __init__(self, min_size, max_size=None): def __call__(self, image, target): size = random.randint(self.min_size, self.max_size) image = F.resize(image, size) - target = F.resize(target, size, interpolation=Image.NEAREST) + target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST) return image, target
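
Note: every script in the series above wires warmup the same way -- a LinearLR or ConstantLR scheduler ramps the learning rate from lr * lr_warmup_decay up to lr over the warmup iterations, and SequentialLR hands control to the main scheduler at that milestone. Below is a minimal standalone sketch of that pattern; the toy model, optimizer and iteration counts are illustrative placeholders rather than values from the patches, and it assumes PyTorch >= 1.10, where LinearLR, ConstantLR and SequentialLR are available.

    import torch

    # Toy model/optimizer, only needed to drive the schedulers.
    model = torch.nn.Linear(10, 10)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    warmup_iters = 5   # e.g. iters_per_epoch * args.lr_warmup_epochs
    total_iters = 20   # e.g. iters_per_epoch * args.epochs

    # Warmup: start at lr * 0.01 and grow linearly to lr over warmup_iters steps.
    warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_iters)
    # Main schedule for the remaining steps (cosine here; the patches also use StepLR, MultiStepLR and LambdaLR).
    main = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_iters - warmup_iters)
    # SequentialLR switches from the warmup scheduler to the main one at the milestone.
    lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup, main], milestones=[warmup_iters])

    for _ in range(total_iters):
        optimizer.step()
        lr_scheduler.step()
        print(lr_scheduler.get_last_lr())

Swapping ConstantLR in place of LinearLR holds the rate flat at lr * factor during warmup and then jumps to the full rate, which is the behaviour the classification script opts into; the linear variant used by the other scripts ramps smoothly instead.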