diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py
index 5070cb554d4..9b88c83df3a 100644
--- a/references/optical_flow/train.py
+++ b/references/optical_flow/train.py
@@ -75,7 +75,7 @@ def _evaluate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, b
         sampler=sampler,
         batch_size=batch_size,
         pin_memory=True,
-        num_workers=args.num_workers,
+        num_workers=args.workers,
     )
 
     num_flow_updates = num_flow_updates or args.num_flow_updates
@@ -269,17 +269,17 @@ def main(args):
         sampler=sampler,
         batch_size=args.batch_size,
         pin_memory=True,
-        num_workers=args.num_workers,
+        num_workers=args.workers,
     )
 
     logger = utils.MetricLogger()
 
     done = False
-    for current_epoch in range(args.start_epoch, args.epochs):
-        print(f"EPOCH {current_epoch}")
+    for epoch in range(args.start_epoch, args.epochs):
+        print(f"EPOCH {epoch}")
         if args.distributed:
             # needed on distributed mode, otherwise the data loading order would be the same for all epochs
-            sampler.set_epoch(current_epoch)
+            sampler.set_epoch(epoch)
 
         train_one_epoch(
             model=model,
@@ -291,20 +291,20 @@ def main(args):
         )
 
         # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
-        print(f"Epoch {current_epoch} done. ", logger)
+        print(f"Epoch {epoch} done. ", logger)
 
         if not args.distributed or args.rank == 0:
             checkpoint = {
                 "model": model_without_ddp.state_dict(),
                 "optimizer": optimizer.state_dict(),
                 "scheduler": scheduler.state_dict(),
-                "epoch": current_epoch,
+                "epoch": epoch,
                 "args": args,
             }
-            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
+            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{epoch}.pth")
             torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")
 
-        if current_epoch % args.val_freq == 0 or done:
+        if epoch % args.val_freq == 0 or done:
             evaluate(model, args)
             model.train()
             if args.freeze_batch_norm:
@@ -319,16 +319,14 @@ def get_args_parser(add_help=True):
         type=str,
         help="The name of the experiment - determines the name of the files where weights are saved.",
     )
-    parser.add_argument(
-        "--output-dir", default="checkpoints", type=str, help="Output dir where checkpoints will be stored."
-    )
+    parser.add_argument("--output-dir", default=".", type=str, help="Output dir where checkpoints will be stored.")
     parser.add_argument(
         "--resume",
         type=str,
         help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.",
     )
-    parser.add_argument("--num-workers", type=int, default=12, help="Number of workers for the data loading part.")
+    parser.add_argument("--workers", type=int, default=12, help="Number of workers for the data loading part.")
     parser.add_argument(
         "--train-dataset",