
Commit 06f1c46

YosuaMichael authored and facebook-github-bot committed
[fbsync] Minor updates to optical flow ref for consistency (#5654)
Summary:
* Minor updates to optical flow ref for consistency
* Actually put back name
* linting

Reviewed By: jdsgomes, NicolasHug

Differential Revision: D36095658

fbshipit-source-id: c189a55787696d35811143d370192f5067b924a5
1 parent d03705e commit 06f1c46

File tree

1 file changed: +11 −13 lines changed


references/optical_flow/train.py

Lines changed: 11 additions & 13 deletions
@@ -75,7 +75,7 @@ def _evaluate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, b
         sampler=sampler,
         batch_size=batch_size,
         pin_memory=True,
-        num_workers=args.num_workers,
+        num_workers=args.workers,
     )

     num_flow_updates = num_flow_updates or args.num_flow_updates
@@ -269,17 +269,17 @@ def main(args):
         sampler=sampler,
         batch_size=args.batch_size,
         pin_memory=True,
-        num_workers=args.num_workers,
+        num_workers=args.workers,
     )

     logger = utils.MetricLogger()

     done = False
-    for current_epoch in range(args.start_epoch, args.epochs):
-        print(f"EPOCH {current_epoch}")
+    for epoch in range(args.start_epoch, args.epochs):
+        print(f"EPOCH {epoch}")
         if args.distributed:
             # needed on distributed mode, otherwise the data loading order would be the same for all epochs
-            sampler.set_epoch(current_epoch)
+            sampler.set_epoch(epoch)

         train_one_epoch(
             model=model,
@@ -291,20 +291,20 @@ def main(args):
         )

         # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
-        print(f"Epoch {current_epoch} done. ", logger)
+        print(f"Epoch {epoch} done. ", logger)

         if not args.distributed or args.rank == 0:
             checkpoint = {
                 "model": model_without_ddp.state_dict(),
                 "optimizer": optimizer.state_dict(),
                 "scheduler": scheduler.state_dict(),
-                "epoch": current_epoch,
+                "epoch": epoch,
                 "args": args,
             }
-            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
+            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{epoch}.pth")
             torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")

-            if current_epoch % args.val_freq == 0 or done:
+            if epoch % args.val_freq == 0 or done:
                 evaluate(model, args)
                 model.train()
                 if args.freeze_batch_norm:
@@ -319,16 +319,14 @@ def get_args_parser(add_help=True):
         type=str,
         help="The name of the experiment - determines the name of the files where weights are saved.",
     )
-    parser.add_argument(
-        "--output-dir", default="checkpoints", type=str, help="Output dir where checkpoints will be stored."
-    )
+    parser.add_argument("--output-dir", default=".", type=str, help="Output dir where checkpoints will be stored.")
     parser.add_argument(
         "--resume",
         type=str,
         help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.",
     )

-    parser.add_argument("--num-workers", type=int, default=12, help="Number of workers for the data loading part.")
+    parser.add_argument("--workers", type=int, default=12, help="Number of workers for the data loading part.")

     parser.add_argument(
         "--train-dataset",
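
One detail worth noting about the rename: argparse derives the attribute name from the long option, so changing the flag from --num-workers to --workers is what forces every args.num_workers access in the script to become args.workers in the same commit. A minimal standalone sketch of that behavior (the parser and argv below are illustrative, not the full train.py parser):

import argparse

# Minimal sketch: argparse turns the long option "--workers" into the
# attribute args.workers, which is why the diff renames both the flag
# ("--num-workers" -> "--workers") and every args.num_workers access.
parser = argparse.ArgumentParser()
parser.add_argument("--workers", type=int, default=12, help="Number of workers for the data loading part.")
args = parser.parse_args(["--workers", "4"])  # illustrative command line
assert args.workers == 4

With the flag renamed, any leftover args.num_workers access would raise AttributeError at runtime, which is why both sides have to change together.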
