diff --git a/references/classification/train.py b/references/classification/train.py
index 3ec9039a018..90abdb0b47e 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -208,7 +208,25 @@ def main(args):
                                           opt_level=args.apex_opt_level
                                           )
 
-    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
+    args.lr_scheduler = args.lr_scheduler.lower()
+    if args.lr_scheduler == 'steplr':
+        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
+    elif args.lr_scheduler == 'cosineannealinglr':
+        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
+                                                                       T_max=args.epochs - args.lr_warmup_epochs)
+    else:
+        raise RuntimeError("Invalid lr scheduler '{}'. Only StepLR and CosineAnnealingLR "
+                           "are supported.".format(args.lr_scheduler))
+
+    if args.lr_warmup_epochs > 0:
+        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
+            optimizer,
+            schedulers=[torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
+                                                            total_iters=args.lr_warmup_epochs), main_lr_scheduler],
+            milestones=[args.lr_warmup_epochs]
+        )
+    else:
+        lr_scheduler = main_lr_scheduler
 
     model_without_ddp = model
     if args.distributed:
@@ -287,6 +305,9 @@ def get_args_parser(add_help=True):
                         dest='label_smoothing')
     parser.add_argument('--mixup-alpha', default=0.0, type=float, help='mixup alpha (default: 0.0)')
     parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
+    parser.add_argument('--lr-scheduler', default="steplr", help='the lr scheduler (default: steplr)')
+    parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
+    parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr')
     parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
diff --git a/references/detection/engine.py b/references/detection/engine.py
index 49992af60a9..82c23c178b1 100644
--- a/references/detection/engine.py
+++ b/references/detection/engine.py
@@ -21,7 +21,8 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
         warmup_factor = 1. / 1000
         warmup_iters = min(1000, len(data_loader) - 1)
 
-        lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)
+        lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=warmup_factor,
+                                                         total_iters=warmup_iters)
 
     for images, targets in metric_logger.log_every(data_loader, print_freq, header):
         images = list(image.to(device) for image in images)
diff --git a/references/detection/utils.py b/references/detection/utils.py
index 3c52abb2167..11fcd3060e4 100644
--- a/references/detection/utils.py
+++ b/references/detection/utils.py
@@ -207,17 +207,6 @@ def collate_fn(batch):
     return tuple(zip(*batch))
 
 
-def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
-
-    def f(x):
-        if x >= warmup_iters:
-            return 1
-        alpha = float(x) / warmup_iters
-        return warmup_factor * (1 - alpha) + alpha
-
-    return torch.optim.lr_scheduler.LambdaLR(optimizer, f)
-
-
 def mkdir(path):
     try:
         os.makedirs(path)
diff --git a/references/segmentation/train.py b/references/segmentation/train.py
index fb6c7eeee15..476058ce0c0 100644
--- a/references/segmentation/train.py
+++ b/references/segmentation/train.py
@@ -133,9 +133,30 @@ def main(args):
         params_to_optimize,
         lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
 
-    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
+    iters_per_epoch = len(data_loader)
+    main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
         optimizer,
-        lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
+        lambda x: (1 - x / (iters_per_epoch * (args.epochs - args.lr_warmup_epochs))) ** 0.9)
+
+    if args.lr_warmup_epochs > 0:
+        warmup_iters = iters_per_epoch * args.lr_warmup_epochs
+        args.lr_warmup_method = args.lr_warmup_method.lower()
+        if args.lr_warmup_method == 'linear':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay,
+                                                                    total_iters=warmup_iters)
+        elif args.lr_warmup_method == 'constant':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
+                                                                      total_iters=warmup_iters)
+        else:
+            raise RuntimeError("Invalid warmup lr method '{}'. Only linear and constant "
+                               "are supported.".format(args.lr_warmup_method))
+        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
+            optimizer,
+            schedulers=[warmup_lr_scheduler, main_lr_scheduler],
+            milestones=[warmup_iters]
+        )
+    else:
+        lr_scheduler = main_lr_scheduler
 
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
@@ -197,6 +218,9 @@ def get_args_parser(add_help=True):
     parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, metavar='W',
                         help='weight decay (default: 1e-4)',
                         dest='weight_decay')
+    parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
+    parser.add_argument('--lr-warmup-method', default="linear", type=str, help='the warmup method (default: linear)')
+    parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')
diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py
index 4fe5a5ad147..fc828f4bab2 100644
--- a/references/segmentation/transforms.py
+++ b/references/segmentation/transforms.py
@@ -37,7 +37,7 @@ def __init__(self, min_size, max_size=None):
     def __call__(self, image, target):
         size = random.randint(self.min_size, self.max_size)
         image = F.resize(image, size)
-        target = F.resize(target, size, interpolation=Image.NEAREST)
+        target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST)
         return image, target
 
 
diff --git a/references/video_classification/scheduler.py b/references/video_classification/scheduler.py
deleted file mode 100644
index f0f862d41ad..00000000000
--- a/references/video_classification/scheduler.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import torch
-from bisect import bisect_right
-
-
-class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
-    def __init__(
-        self,
-        optimizer,
-        milestones,
-        gamma=0.1,
-        warmup_factor=1.0 / 3,
-        warmup_iters=5,
-        warmup_method="linear",
-        last_epoch=-1,
-    ):
-        if not milestones == sorted(milestones):
-            raise ValueError(
-                "Milestones should be a list of" " increasing integers. Got {}",
-                milestones,
-            )
-
-        if warmup_method not in ("constant", "linear"):
-            raise ValueError(
-                "Only 'constant' or 'linear' warmup_method accepted"
-                "got {}".format(warmup_method)
-            )
-        self.milestones = milestones
-        self.gamma = gamma
-        self.warmup_factor = warmup_factor
-        self.warmup_iters = warmup_iters
-        self.warmup_method = warmup_method
-        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
-
-    def get_lr(self):
-        warmup_factor = 1
-        if self.last_epoch < self.warmup_iters:
-            if self.warmup_method == "constant":
-                warmup_factor = self.warmup_factor
-            elif self.warmup_method == "linear":
-                alpha = float(self.last_epoch) / self.warmup_iters
-                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
-        return [
-            base_lr *
-            warmup_factor *
-            self.gamma ** bisect_right(self.milestones, self.last_epoch)
-            for base_lr in self.base_lrs
-        ]
diff --git a/references/video_classification/train.py b/references/video_classification/train.py
index 11ac2d5378d..353e0d6d1f7 100644
--- a/references/video_classification/train.py
+++ b/references/video_classification/train.py
@@ -12,8 +12,6 @@
 import presets
 import utils
 
-from scheduler import WarmupMultiStepLR
-
 try:
     from apex import amp
 except ImportError:
@@ -202,11 +200,30 @@ def main(args):
 
     # convert scheduler to be per iteration, not per epoch, for warmup that lasts
     # between different epochs
-    warmup_iters = args.lr_warmup_epochs * len(data_loader)
-    lr_milestones = [len(data_loader) * m for m in args.lr_milestones]
-    lr_scheduler = WarmupMultiStepLR(
-        optimizer, milestones=lr_milestones, gamma=args.lr_gamma,
-        warmup_iters=warmup_iters, warmup_factor=1e-5)
+    iters_per_epoch = len(data_loader)
+    lr_milestones = [iters_per_epoch * (m - args.lr_warmup_epochs) for m in args.lr_milestones]
+    main_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones, gamma=args.lr_gamma)
+
+    if args.lr_warmup_epochs > 0:
+        warmup_iters = iters_per_epoch * args.lr_warmup_epochs
+        args.lr_warmup_method = args.lr_warmup_method.lower()
+        if args.lr_warmup_method == 'linear':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay,
+                                                                    total_iters=warmup_iters)
+        elif args.lr_warmup_method == 'constant':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
+                                                                      total_iters=warmup_iters)
+        else:
+            raise RuntimeError("Invalid warmup lr method '{}'. Only linear and constant "
+                               "are supported.".format(args.lr_warmup_method))
+
+        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
+            optimizer,
+            schedulers=[warmup_lr_scheduler, main_lr_scheduler],
+            milestones=[warmup_iters]
+        )
+    else:
+        lr_scheduler = main_lr_scheduler
 
     model_without_ddp = model
     if args.distributed:
@@ -277,7 +294,9 @@ def parse_args():
                         dest='weight_decay')
     parser.add_argument('--lr-milestones', nargs='+', default=[20, 30, 40], type=int, help='decrease lr on milestones')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
-    parser.add_argument('--lr-warmup-epochs', default=10, type=int, help='number of warmup epochs')
+    parser.add_argument('--lr-warmup-epochs', default=10, type=int, help='the number of epochs to warmup (default: 10)')
+    parser.add_argument('--lr-warmup-method', default="linear", type=str, help='the warmup method (default: linear)')
+    parser.add_argument('--lr-warmup-decay', default=0.001, type=float, help='the decay for lr')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')
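
Note: the changes above replace the hand-rolled warmup schedulers (utils.warmup_lr_scheduler and WarmupMultiStepLR) with compositions of the built-in LinearLR/ConstantLR and SequentialLR classes. The short sketch below is not part of the patch; the dummy model, optimizer and hyperparameter values are placeholders chosen only to illustrate how such a composition behaves, assuming a PyTorch release that ships SequentialLR, LinearLR and ConstantLR (roughly 1.10 and later).

import torch

# Dummy model and optimizer, for illustration only.
model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Same pattern as the patch: chain a warmup scheduler in front of the main one.
warmup_epochs = 3
warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_epochs)
main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_epochs])

for epoch in range(10):
    optimizer.step()  # placeholder for one epoch of training
    lr_scheduler.step()
    # The lr ramps linearly from 0.001 to 0.1 over the first 3 epochs, then StepLR takes over.
    print(epoch, optimizer.param_groups[0]["lr"])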