@@ -1,5 +1,6 @@
 import argparse
 import warnings
+from math import ceil
 from pathlib import Path

 import torch
@@ -168,7 +169,7 @@ def validate(model, args):
             warnings.warn(f"Can't validate on {val_dataset}, skipping.")


-def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args):
+def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
     for data_blob in logger.log_every(train_loader):

         optimizer.zero_grad()
@@ -189,13 +190,6 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args):
         optimizer.step()
         scheduler.step()

-        current_step += 1
-
-        if current_step == args.num_steps:
-            return True, current_step
-
-    return False, current_step
-

 def main(args):
     utils.setup_ddp(args)
@@ -243,7 +237,8 @@ def main(args):
     scheduler = torch.optim.lr_scheduler.OneCycleLR(
         optimizer=optimizer,
         max_lr=args.lr,
-        total_steps=args.num_steps + 100,
+        epochs=args.epochs,
+        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
         pct_start=0.05,
         cycle_momentum=False,
         anneal_strategy="linear",
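With the fixed step budget gone, the OneCycleLR horizon is now derived from the data loader: `epochs * steps_per_epoch` should match the number of `scheduler.step()` calls, and each rank processes roughly `len(train_dataset) / (world_size * batch_size)` batches per epoch because the dataset is sharded across processes. A minimal sketch of that arithmetic, using made-up dataset and world sizes (not values from this diff):

    from math import ceil

    # Placeholder sizes, only to illustrate how the scheduler horizon is computed.
    dataset_len, world_size, batch_size, epochs = 22_872, 8, 2, 20

    # Each of the `world_size` processes sees ~1/world_size of the dataset per epoch,
    # grouped into batches of `batch_size` samples.
    steps_per_epoch = ceil(dataset_len / (world_size * batch_size))  # 1430

    # OneCycleLR anneals over epochs * steps_per_epoch updates, so the learning rate
    # reaches its final value exactly when training ends (no more `+ 100` padding).
    total_scheduler_steps = epochs * steps_per_epoch  # 28600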
@@ -252,26 +247,22 @@ def main(args):
     logger = utils.MetricLogger()

     done = False
-    current_epoch = current_step = 0
-    while not done:
+    for current_epoch in range(args.epochs):
         print(f"EPOCH {current_epoch}")

         sampler.set_epoch(current_epoch)  # needed, otherwise the data loading order would be the same for all epochs
-        done, current_step = train_one_epoch(
+        train_one_epoch(
             model=model,
             optimizer=optimizer,
             scheduler=scheduler,
             train_loader=train_loader,
             logger=logger,
-            current_step=current_step,
             args=args,
         )

         # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
         print(f"Epoch {current_epoch} done. ", logger)

-        current_epoch += 1
-
         if args.rank == 0:
             # TODO: Also save the optimizer and scheduler
             torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
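Rank 0 now writes a checkpoint at the end of every epoch, so evaluating or resuming from one is just a matter of loading the saved state dict. Below is a rough sketch of a hypothetical `load_epoch_checkpoint` helper (not part of this diff) that mirrors the `{name}_{epoch}.pth` naming above; if `model` is DDP-wrapped at save time, as is typical in this kind of script, the keys carry a `module.` prefix that needs stripping when loading into a bare model:

    from pathlib import Path

    import torch

    def load_epoch_checkpoint(model, output_dir, name, epoch):
        # Hypothetical helper; the file name follows the f"{args.name}_{current_epoch}.pth"
        # pattern used by the training loop above.
        state_dict = torch.load(Path(output_dir) / f"{name}_{epoch}.pth", map_location="cpu")
        # Strip the "module." prefix added by DistributedDataParallel so the weights
        # can also be loaded into a non-wrapped model (requires Python 3.9+ for removeprefix).
        state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
        return model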
@@ -310,10 +301,8 @@ def get_args_parser(add_help=True):
     )
     parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
    parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs")
-    # TODO: eventually, it might be preferable to support epochs instead of num_steps.
-    # Keeping it this way for now to reproduce results more easily.
-    parser.add_argument("--num-steps", type=int, default=100000, help="The total number of steps (updates) to train.")
-    parser.add_argument("--batch-size", type=int, default=6)
+    parser.add_argument("--epochs", type=int, default=20, help="The total number of epochs to train.")
+    parser.add_argument("--batch-size", type=int, default=2)

     parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer")
     parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer")
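With the new flags, a run is sized in epochs rather than optimizer updates. A hypothetical launch line for a distributed run (the script name, process count, and any additional dataset/output flags the full script requires are assumptions, not taken from this diff):

    torchrun --nproc_per_node 8 train.py \
        --epochs 20 --batch-size 2 \
        --lr 0.00002 --weight-decay 0.00005 \
        --val-freq 2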