@@ -221,10 +221,6 @@ def main(args):
     model = model.to(args.local_rank)
     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])

-    if args.resume is not None:
-        d = torch.load(args.resume, map_location="cpu")
-        model.load_state_dict(d, strict=True)
-
     if args.train_dataset is None:
         # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
         torch.backends.cudnn.benchmark = False
@@ -234,14 +230,35 @@ def main(args):

     print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

+    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
+
+    scheduler = torch.optim.lr_scheduler.OneCycleLR(
+        optimizer=optimizer,
+        max_lr=args.lr,
+        epochs=args.epochs,
+        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
+        pct_start=0.05,
+        cycle_momentum=False,
+        anneal_strategy="linear",
+    )
+
+    if args.resume is not None:
+        checkpoint = torch.load(args.resume, map_location="cpu")
+        model.load_state_dict(checkpoint["model"])
+        optimizer.load_state_dict(checkpoint["optimizer"])
+        scheduler.load_state_dict(checkpoint["scheduler"])
+        args.start_epoch = checkpoint["epoch"] + 1
+    else:
+        args.start_epoch = 0
+
     torch.backends.cudnn.benchmark = True

     model.train()
     if args.freeze_batch_norm:
         utils.freeze_batch_norm(model.module)

-    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
-
     if args.distributed:
         sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
     else:
@@ -255,22 +272,10 @@ def main(args):
         num_workers=args.num_workers,
     )

-    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
-
-    scheduler = torch.optim.lr_scheduler.OneCycleLR(
-        optimizer=optimizer,
-        max_lr=args.lr,
-        epochs=args.epochs,
-        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
-        pct_start=0.05,
-        cycle_momentum=False,
-        anneal_strategy="linear",
-    )
-
     logger = utils.MetricLogger()

     done = False
-    for current_epoch in range(args.epochs):
+    for current_epoch in range(args.start_epoch, args.epochs):
         print(f"EPOCH {current_epoch}")
         if args.distributed:
             # needed on distributed mode, otherwise the data loading order would be the same for all epochs
@@ -289,9 +294,15 @@ def main(args):
         print(f"Epoch {current_epoch} done. ", logger)

         if not args.distributed or args.rank == 0:
-            # TODO: Also save the optimizer and scheduler
-            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
-            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth")
+            checkpoint = {
+                "model": model.state_dict(),
+                "optimizer": optimizer.state_dict(),
+                "scheduler": scheduler.state_dict(),
+                "epoch": current_epoch,
+                "args": args,
+            }
+            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
+            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")

         if current_epoch % args.val_freq == 0 or done:
             evaluate(model, args)
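For reference, a minimal sketch of consuming a checkpoint written by the new saving code above (a dict with "model", "optimizer", "scheduler", "epoch", and "args" keys rather than a bare state_dict). The checkpoint path and the build_model() constructor are placeholders, not part of this diff.

import torch

# Load the full checkpoint dict saved by the training loop above.
checkpoint = torch.load("checkpoints/model.pth", map_location="cpu")  # path is illustrative

# The weights were saved from the DistributedDataParallel wrapper (the script
# wraps the model in DDP before saving), so parameter names carry a "module."
# prefix; strip it before loading into an unwrapped model.
state_dict = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in checkpoint["model"].items()
}

model = build_model()  # placeholder for whatever constructor the script used
model.load_state_dict(state_dict, strict=True)
model.eval()

# Resuming via --resume would continue from the epoch after the one stored here.
print("next epoch would be", checkpoint["epoch"] + 1)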