@@ -1,5 +1,6 @@
 import argparse
 import warnings
+from math import ceil
 from pathlib import Path
 
 import torch
@@ -168,7 +169,7 @@ def validate(model, args):
             warnings.warn(f"Can't validate on {val_dataset}, skipping.")
 
 
-def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args):
+def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
     for data_blob in logger.log_every(train_loader):
 
         optimizer.zero_grad()
@@ -189,13 +190,6 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_s
         optimizer.step()
         scheduler.step()
 
-        current_step += 1
-
-        if current_step == args.num_steps:
-            return True, current_step
-
-        return False, current_step
-
 
 def main(args):
     utils.setup_ddp(args)
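With the step counter gone, `train_one_epoch` now just walks the loader once and returns nothing; how many times it runs is decided by the caller. A minimal sketch of the resulting shape follows. The RAFT forward pass and flow loss that sit between `zero_grad()` and `backward()` are not shown in this diff and are collapsed here into a hypothetical `compute_loss` stand-in.

```python
def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
    for data_blob in logger.log_every(train_loader):
        optimizer.zero_grad()
        loss = compute_loss(model, data_blob, args)  # hypothetical stand-in for the elided forward/loss code
        loss.backward()
        optimizer.step()
        scheduler.step()  # OneCycleLR is stepped once per optimizer update
    # No return value: the number of epochs is controlled by the loop in main().
```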
@@ -243,7 +237,8 @@ def main(args):
     scheduler = torch.optim.lr_scheduler.OneCycleLR(
         optimizer=optimizer,
         max_lr=args.lr,
-        total_steps=args.num_steps + 100,
+        epochs=args.epochs,
+        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
         pct_start=0.05,
         cycle_momentum=False,
         anneal_strategy="linear",
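A minimal, self-contained sketch (toy model and toy schedule sizes, not values from this PR) of why the `epochs`/`steps_per_epoch` pair can replace `total_steps`: when `total_steps` is not given, `OneCycleLR` computes it as `epochs * steps_per_epoch`, so the schedule length matches the real number of updates instead of relying on a hard-coded `num_steps + 100` margin.

```python
import torch

# Toy model and schedule sizes, just to exercise the API.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=5e-5)

epochs, steps_per_epoch = 3, 4
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=2e-5,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    pct_start=0.05,
    cycle_momentum=False,  # AdamW param groups have betas, not momentum
    anneal_strategy="linear",
)
assert scheduler.total_steps == epochs * steps_per_epoch

# Stepping past total_steps raises a ValueError, which is presumably why the
# old code padded total_steps with "+ 100"; sizing the schedule from the data
# makes that slack unnecessary.
for _ in range(epochs * steps_per_epoch):
    optimizer.step()
    scheduler.step()
```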
@@ -252,26 +247,22 @@ def main(args):
     logger = utils.MetricLogger()
 
     done = False
-    current_epoch = current_step = 0
-    while not done:
+    for current_epoch in range(args.epochs):
         print(f"EPOCH {current_epoch}")
 
         sampler.set_epoch(current_epoch)  # needed, otherwise the data loading order would be the same for all epochs
-        done, current_step = train_one_epoch(
+        train_one_epoch(
             model=model,
             optimizer=optimizer,
             scheduler=scheduler,
             train_loader=train_loader,
             logger=logger,
-            current_step=current_step,
             args=args,
         )
 
         # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
         print(f"Epoch {current_epoch} done. ", logger)
 
-        current_epoch += 1
-
         if args.rank == 0:
             # TODO: Also save the optimizer and scheduler
             torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
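The `sampler.set_epoch(current_epoch)` call kept above matters more now that epochs drive the loop. A minimal single-process sketch of the behaviour the inline comment describes, using `num_replicas=1`/`rank=0` so no process group is needed and a made-up toy dataset:

```python
from torch.utils.data import DistributedSampler

dataset = list(range(8))  # toy stand-in; the sampler only needs len(dataset)
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)

# The shuffle is derived from seed + epoch. Without set_epoch the epoch stays 0,
# so every pass replays the identical order:
assert list(iter(sampler)) == list(iter(sampler))

# Calling set_epoch before each epoch reseeds the shuffle, which is what the
# training loop above relies on:
sampler.set_epoch(0)
order_epoch_0 = list(iter(sampler))
sampler.set_epoch(1)
order_epoch_1 = list(iter(sampler))
print(order_epoch_0, order_epoch_1)  # the orders differ across epochs
```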
@@ -310,10 +301,8 @@ def get_args_parser(add_help=True):
     )
     parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
     parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs")
-    # TODO: eventually, it might be preferable to support epochs instead of num_steps.
-    # Keeping it this way for now to reproduce results more easily.
-    parser.add_argument("--num-steps", type=int, default=100000, help="The total number of steps (updates) to train.")
-    parser.add_argument("--batch-size", type=int, default=6)
+    parser.add_argument("--epochs", type=int, default=20, help="The total number of epochs to train.")
+    parser.add_argument("--batch-size", type=int, default=2)
 
     parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer")
     parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer")
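Rough arithmetic for what the new flags imply (the dataset size and GPU count below are made-up placeholders, not numbers from this PR): the number of optimizer updates is no longer a fixed `--num-steps` but `epochs * ceil(len(dataset) / (world_size * batch_size))`.

```python
from math import ceil

epochs, batch_size = 20, 2             # new CLI defaults
world_size, dataset_len = 8, 100_000   # made-up illustration values

steps_per_epoch = ceil(dataset_len / (world_size * batch_size))  # 6250
total_updates = epochs * steps_per_epoch                         # 125000
print(steps_per_epoch, total_updates)
```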