@@ -75,7 +75,7 @@ def _evaluate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, b
75
75
sampler = sampler ,
76
76
batch_size = batch_size ,
77
77
pin_memory = True ,
78
- num_workers = args .num_workers ,
78
+ num_workers = args .workers ,
79
79
)
80
80
81
81
num_flow_updates = num_flow_updates or args .num_flow_updates
@@ -269,17 +269,17 @@ def main(args):
269
269
sampler = sampler ,
270
270
batch_size = args .batch_size ,
271
271
pin_memory = True ,
272
- num_workers = args .num_workers ,
272
+ num_workers = args .workers ,
273
273
)
274
274
275
275
logger = utils .MetricLogger ()
276
276
277
277
done = False
278
- for current_epoch in range (args .start_epoch , args .epochs ):
279
- print(f"EPOCH {current_epoch}")
278
+ for epoch in range (args .start_epoch , args .epochs ):
279
+ print(f"EPOCH {epoch}")
280
280
if args .distributed :
281
281
# needed on distributed mode, otherwise the data loading order would be the same for all epochs
282
- sampler .set_epoch (current_epoch )
282
+ sampler .set_epoch (epoch )
283
283
284
284
train_one_epoch (
285
285
model = model ,
@@ -291,20 +291,20 @@ def main(args):
291
291
)
292
292
293
293
# Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
294
- print(f"Epoch {current_epoch} done. ", logger)
294
+ print(f"Epoch {epoch} done. ", logger)
295
295
296
296
if not args .distributed or args .rank == 0 :
297
297
checkpoint = {
298
298
"model" : model_without_ddp .state_dict (),
299
299
"optimizer" : optimizer .state_dict (),
300
300
"scheduler" : scheduler .state_dict (),
301
- "epoch" : current_epoch ,
301
+ "epoch" : epoch ,
302
302
"args" : args ,
303
303
}
304
- torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
304
+ torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{epoch}.pth")
305
305
torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")
306
306
307
- if current_epoch % args .val_freq == 0 or done :
307
+ if epoch % args .val_freq == 0 or done :
308
308
evaluate (model , args )
309
309
model .train ()
310
310
if args .freeze_batch_norm :
@@ -319,16 +319,14 @@ def get_args_parser(add_help=True):
319
319
type = str ,
320
320
help = "The name of the experiment - determines the name of the files where weights are saved." ,
321
321
)
322
- parser .add_argument (
323
- "--output-dir" , default = "checkpoints" , type = str , help = "Output dir where checkpoints will be stored."
324
- )
322
+ parser .add_argument ("--output-dir" , default = "." , type = str , help = "Output dir where checkpoints will be stored." )
325
323
parser .add_argument (
326
324
"--resume" ,
327
325
type = str ,
328
326
help = "A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model." ,
329
327
)
330
328
331
- parser.add_argument("--num-workers", type=int, default=12, help="Number of workers for the data loading part.")
329
+ parser.add_argument("--workers", type=int, default=12, help="Number of workers for the data loading part.")
332
330
333
331
parser .add_argument (
334
332
"--train-dataset" ,
0 commit comments