diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py
index 7f2f362c73d..83952242eb9 100644
--- a/references/optical_flow/train.py
+++ b/references/optical_flow/train.py
@@ -60,16 +60,21 @@ def get_train_dataset(stage, dataset_root):
 
 
 @torch.no_grad()
-def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
+def _evaluate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
     """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.
 
     We process as many samples as possible with ddp, and process the rest on a single worker.
     """
     batch_size = batch_size or args.batch_size
+    device = torch.device(args.device)
 
     model.eval()
 
-    sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
+    if args.distributed:
+        sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
+    else:
+        sampler = torch.utils.data.SequentialSampler(val_dataset)
+
     val_loader = torch.utils.data.DataLoader(
         val_dataset,
         sampler=sampler,
@@ -88,7 +93,7 @@ def inner_loop(blob):
         image1, image2, flow_gt = blob[:3]
         valid_flow_mask = None if len(blob) == 3 else blob[-1]
 
-        image1, image2 = image1.cuda(), image2.cuda()
+        image1, image2 = image1.to(device), image2.to(device)
 
         padder = utils.InputPadder(image1.shape, mode=padder_mode)
         image1, image2 = padder.pad(image1, image2)
@@ -115,21 +120,22 @@ def inner_loop(blob):
         inner_loop(blob)
         num_processed_samples += blob[0].shape[0]  # batch size
 
-    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
-    print(
-        f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
-        "Going to process the remaining samples individually, if any."
-    )
+    if args.distributed:
+        num_processed_samples = utils.reduce_across_processes(num_processed_samples)
+        print(
+            f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
+            "Going to process the remaining samples individually, if any."
+        )
+        if args.rank == 0:  # we only need to process the rest on a single worker
+            for i in range(num_processed_samples, len(val_dataset)):
+                inner_loop(val_dataset[i])
 
-    if args.rank == 0:  # we only need to process the rest on a single worker
-        for i in range(num_processed_samples, len(val_dataset)):
-            inner_loop(val_dataset[i])
+        logger.synchronize_between_processes()
 
-    logger.synchronize_between_processes()
     print(header, logger)
 
 
-def validate(model, args):
+def evaluate(model, args):
     val_datasets = args.val_dataset or []
 
     if args.prototype:
@@ -145,13 +151,13 @@ def validate(model, args):
         if name == "kitti":
            # Kitti has different image sizes so we need to individually pad them, we can't batch.
            # see comment in InputPadder
-            if args.batch_size != 1 and args.rank == 0:
+            if args.batch_size != 1 and (not args.distributed or args.rank == 0):
                 warnings.warn(
                     f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
                 )
 
             val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
-            _validate(
+            _evaluate(
                 model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
             )
         elif name == "sintel":
@@ -159,7 +165,7 @@ def validate(model, args):
                 val_dataset = Sintel(
                     root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
                 )
-                _validate(
+                _evaluate(
                     model,
                     args,
                     val_dataset,
@@ -172,11 +178,12 @@ def validate(model, args):
 
 
 def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
+    device = torch.device(args.device)
     for data_blob in logger.log_every(train_loader):
 
         optimizer.zero_grad()
 
-        image1, image2, flow_gt, valid_flow_mask = (x.cuda() for x in data_blob)
+        image1, image2, flow_gt, valid_flow_mask = (x.to(device) for x in data_blob)
         flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)
 
         loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma)
@@ -200,36 +207,68 @@ def main(args):
         raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
 
     utils.setup_ddp(args)
+    if args.distributed and args.device == "cpu":
+        raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
+    device = torch.device(args.device)
+
     if args.prototype:
         model = prototype.models.optical_flow.__dict__[args.model](weights=args.weights)
     else:
         model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained)
 
-    model = model.to(args.local_rank)
-    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
+    if args.distributed:
+        model = model.to(args.local_rank)
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
+        model_without_ddp = model.module
+    else:
+        model.to(device)
+        model_without_ddp = model
 
     if args.resume is not None:
-        d = torch.load(args.resume, map_location="cpu")
-        model.load_state_dict(d, strict=True)
+        checkpoint = torch.load(args.resume, map_location="cpu")
+        model_without_ddp.load_state_dict(checkpoint["model"])
 
     if args.train_dataset is None:
         # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
         torch.backends.cudnn.benchmark = False
         torch.backends.cudnn.deterministic = True
-        validate(model, args)
+        evaluate(model, args)
         return
 
     print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
+    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
+
+    scheduler = torch.optim.lr_scheduler.OneCycleLR(
+        optimizer=optimizer,
+        max_lr=args.lr,
+        epochs=args.epochs,
+        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
+        pct_start=0.05,
+        cycle_momentum=False,
+        anneal_strategy="linear",
+    )
+
+    if args.resume is not None:
+        optimizer.load_state_dict(checkpoint["optimizer"])
+        scheduler.load_state_dict(checkpoint["scheduler"])
+        args.start_epoch = checkpoint["epoch"] + 1
+    else:
+        args.start_epoch = 0
+
     torch.backends.cudnn.benchmark = True
 
     model.train()
     if args.freeze_batch_norm:
         utils.freeze_batch_norm(model.module)
 
-    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
+    if args.distributed:
+        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
+    else:
+        sampler = torch.utils.data.RandomSampler(train_dataset)
 
-    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
     train_loader = torch.utils.data.DataLoader(
         train_dataset,
         sampler=sampler,
@@ -238,25 +277,15 @@ def main(args):
         num_workers=args.num_workers,
     )
 
-    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
-
-    scheduler = torch.optim.lr_scheduler.OneCycleLR(
-        optimizer=optimizer,
-        max_lr=args.lr,
-        epochs=args.epochs,
-        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
-        pct_start=0.05,
-        cycle_momentum=False,
-        anneal_strategy="linear",
-    )
-
     logger = utils.MetricLogger()
 
     done = False
-    for current_epoch in range(args.epochs):
+    for current_epoch in range(args.start_epoch, args.epochs):
         print(f"EPOCH {current_epoch}")
+        if args.distributed:
+            # needed on distributed mode, otherwise the data loading order would be the same for all epochs
+            sampler.set_epoch(current_epoch)
 
-        sampler.set_epoch(current_epoch)  # needed, otherwise the data loading order would be the same for all epochs
         train_one_epoch(
             model=model,
             optimizer=optimizer,
@@ -269,13 +298,19 @@ def main(args):
         # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
         print(f"Epoch {current_epoch} done. ", logger)
 
-        if args.rank == 0:
-            # TODO: Also save the optimizer and scheduler
-            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
-            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth")
+        if not args.distributed or args.rank == 0:
+            checkpoint = {
+                "model": model_without_ddp.state_dict(),
+                "optimizer": optimizer.state_dict(),
+                "scheduler": scheduler.state_dict(),
+                "epoch": current_epoch,
+                "args": args,
+            }
+            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
+            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")
 
         if current_epoch % args.val_freq == 0 or done:
-            validate(model, args)
+            evaluate(model, args)
             model.train()
             if args.freeze_batch_norm:
                 utils.freeze_batch_norm(model.module)
@@ -349,6 +384,7 @@ def get_args_parser(add_help=True):
         action="store_true",
     )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
+    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu, Default: cuda)")
 
     return parser
diff --git a/references/optical_flow/utils.py b/references/optical_flow/utils.py
index acdc49bd1f7..4b6d0049f54 100644
--- a/references/optical_flow/utils.py
+++ b/references/optical_flow/utils.py
@@ -256,7 +256,12 @@ def setup_ddp(args):
         # if we're here, the script was called by run_with_submitit.py
         args.local_rank = args.gpu
     else:
-        raise ValueError(r"Sorry, I can't set up the distributed training ¯\_(ツ)_/¯.")
+        print("Not using distributed mode!")
+        args.distributed = False
+        args.world_size = 1
+        return
+
+    args.distributed = True
 
     _redefine_print(is_main=(args.rank == 0))