From 37901e8805739c6bdc6ce3408da597deea00f19d Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Mon, 21 Mar 2022 17:15:26 +0000
Subject: [PATCH 1/3] Minor updates to optical flow ref for consistency

---
 references/optical_flow/train.py | 30 ++++++++++++------------------
 1 file changed, 12 insertions(+), 18 deletions(-)

diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py
index 83952242eb9..6d1be58f959 100644
--- a/references/optical_flow/train.py
+++ b/references/optical_flow/train.py
@@ -80,7 +80,7 @@ def _evaluate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, b
         sampler=sampler,
         batch_size=batch_size,
         pin_memory=True,
-        num_workers=args.num_workers,
+        num_workers=args.workers,
     )
 
     num_flow_updates = num_flow_updates or args.num_flow_updates
@@ -274,17 +274,17 @@ def main(args):
         sampler=sampler,
         batch_size=args.batch_size,
         pin_memory=True,
-        num_workers=args.num_workers,
+        num_workers=args.workers,
     )
 
     logger = utils.MetricLogger()
 
     done = False
-    for current_epoch in range(args.start_epoch, args.epochs):
-        print(f"EPOCH {current_epoch}")
+    for epoch in range(args.start_epoch, args.epochs):
+        print(f"EPOCH {epoch}")
         if args.distributed:
             # needed on distributed mode, otherwise the data loading order would be the same for all epochs
-            sampler.set_epoch(current_epoch)
+            sampler.set_epoch(epoch)
 
         train_one_epoch(
             model=model,
@@ -296,20 +296,20 @@ def main(args):
         )
 
         # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
-        print(f"Epoch {current_epoch} done. ", logger)
+        print(f"Epoch {epoch} done. ", logger)
 
         if not args.distributed or args.rank == 0:
             checkpoint = {
                 "model": model_without_ddp.state_dict(),
                 "optimizer": optimizer.state_dict(),
                 "scheduler": scheduler.state_dict(),
-                "epoch": current_epoch,
+                "epoch": epoch,
                 "args": args,
             }
-            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
-            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")
+            torch.save(checkpoint, Path(args.output_dir) / f"model_{epoch}.pth")
+            torch.save(checkpoint, Path(args.output_dir) / f"checkpoint.pth")
 
-        if current_epoch % args.val_freq == 0 or done:
+        if epoch % args.val_freq == 0 or done:
             evaluate(model, args)
             model.train()
             if args.freeze_batch_norm:
@@ -319,13 +319,7 @@ def get_args_parser(add_help=True):
     parser = argparse.ArgumentParser(add_help=add_help, description="Train or evaluate an optical-flow model.")
     parser.add_argument(
-        "--name",
-        default="raft",
-        type=str,
-        help="The name of the experiment - determines the name of the files where weights are saved.",
-    )
-    parser.add_argument(
-        "--output-dir", default="checkpoints", type=str, help="Output dir where checkpoints will be stored."
+        "--output-dir", default=".", type=str, help="Output dir where checkpoints will be stored."
     )
     parser.add_argument(
         "--resume",
         type=str,
         help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.",
     )
-    parser.add_argument("--num-workers", type=int, default=12, help="Number of workers for the data loading part.")
+    parser.add_argument("--workers", type=int, default=12, help="Number of workers for the data loading part.")
 
     parser.add_argument(
         "--train-dataset",

From 84e0a70ee370955b4bc11834b3dca879fe0dd03c Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Mon, 21 Mar 2022 17:21:30 +0000
Subject: [PATCH 2/3] Actually put back name

---
 references/optical_flow/train.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py
index 6d1be58f959..6a3c543046c 100644
--- a/references/optical_flow/train.py
+++ b/references/optical_flow/train.py
@@ -306,8 +306,8 @@ def main(args):
                 "epoch": epoch,
                 "args": args,
             }
-            torch.save(checkpoint, Path(args.output_dir) / f"model_{epoch}.pth")
-            torch.save(checkpoint, Path(args.output_dir) / f"checkpoint.pth")
+            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{epoch}.pth")
+            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")
 
         if epoch % args.val_freq == 0 or done:
             evaluate(model, args)
@@ -318,6 +318,12 @@ def get_args_parser(add_help=True):
     parser = argparse.ArgumentParser(add_help=add_help, description="Train or evaluate an optical-flow model.")
+    parser.add_argument(
+        "--name",
+        default="raft",
+        type=str,
+        help="The name of the experiment - determines the name of the files where weights are saved.",
+    )
     parser.add_argument(
         "--output-dir", default=".", type=str, help="Output dir where checkpoints will be stored."
     )
     parser.add_argument(
         "--resume",

From 8bad400e5c6e58f661dddad7cb25c8aa787059ae Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Tue, 22 Mar 2022 09:01:28 +0000
Subject: [PATCH 3/3] linting

---
 references/optical_flow/train.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py
index 6a3c543046c..ee5d14a6508 100644
--- a/references/optical_flow/train.py
+++ b/references/optical_flow/train.py
@@ -324,9 +324,7 @@ def get_args_parser(add_help=True):
         type=str,
         help="The name of the experiment - determines the name of the files where weights are saved.",
     )
-    parser.add_argument(
-        "--output-dir", default=".", type=str, help="Output dir where checkpoints will be stored."
-    )
+    parser.add_argument("--output-dir", default=".", type=str, help="Output dir where checkpoints will be stored.")
     parser.add_argument(
         "--resume",
         type=str,