@@ -221,10 +221,6 @@ def main(args):
221221 model = model .to (args .local_rank )
222222 model = torch .nn .parallel .DistributedDataParallel (model , device_ids = [args .local_rank ])
223223
224- if args .resume is not None :
225- d = torch .load (args .resume , map_location = "cpu" )
226- model .load_state_dict (d , strict = True )
227-
228224 if args .train_dataset is None :
229225 # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
230226 torch .backends .cudnn .benchmark = False
@@ -234,14 +230,35 @@ def main(args):
234230
235231 print (f"Parameter Count: { sum (p .numel () for p in model .parameters () if p .requires_grad )} " )
236232
233+ train_dataset = get_train_dataset (args .train_dataset , args .dataset_root )
234+
235+ optimizer = torch .optim .AdamW (model .parameters (), lr = args .lr , weight_decay = args .weight_decay , eps = args .adamw_eps )
236+
237+ scheduler = torch .optim .lr_scheduler .OneCycleLR (
238+ optimizer = optimizer ,
239+ max_lr = args .lr ,
240+ epochs = args .epochs ,
241+ steps_per_epoch = ceil (len (train_dataset ) / (args .world_size * args .batch_size )),
242+ pct_start = 0.05 ,
243+ cycle_momentum = False ,
244+ anneal_strategy = "linear" ,
245+ )
246+
247+ if args .resume is not None :
248+ checkpoint = torch .load (args .resume , map_location = "cpu" )
249+ model .load_state_dict (checkpoint ["model" ])
250+ optimizer .load_state_dict (checkpoint ["optimizer" ])
251+ scheduler .load_state_dict (checkpoint ["scheduler" ])
252+ args .start_epoch = checkpoint ["epoch" ] + 1
253+ else :
254+ args .start_epoch = 0
255+
237256 torch .backends .cudnn .benchmark = True
238257
239258 model .train ()
240259 if args .freeze_batch_norm :
241260 utils .freeze_batch_norm (model .module )
242261
243- train_dataset = get_train_dataset (args .train_dataset , args .dataset_root )
244-
245262 if args .distributed :
246263 sampler = torch .utils .data .distributed .DistributedSampler (train_dataset , shuffle = True , drop_last = True )
247264 else :
@@ -255,22 +272,10 @@ def main(args):
255272 num_workers = args .num_workers ,
256273 )
257274
258- optimizer = torch .optim .AdamW (model .parameters (), lr = args .lr , weight_decay = args .weight_decay , eps = args .adamw_eps )
259-
260- scheduler = torch .optim .lr_scheduler .OneCycleLR (
261- optimizer = optimizer ,
262- max_lr = args .lr ,
263- epochs = args .epochs ,
264- steps_per_epoch = ceil (len (train_dataset ) / (args .world_size * args .batch_size )),
265- pct_start = 0.05 ,
266- cycle_momentum = False ,
267- anneal_strategy = "linear" ,
268- )
269-
270275 logger = utils .MetricLogger ()
271276
272277 done = False
273- for current_epoch in range (args .epochs ):
278+ for current_epoch in range (args .start_epoch , args . epochs ):
274279 print (f"EPOCH { current_epoch } " )
275280 if args .distributed :
276281 # needed in distributed mode; otherwise the data loading order would be the same for all epochs
@@ -289,9 +294,15 @@ def main(args):
289294 print (f"Epoch { current_epoch } done. " , logger )
290295
291296 if not args .distributed or args .rank == 0 :
292- # TODO: Also save the optimizer and scheduler
293- torch .save (model .state_dict (), Path (args .output_dir ) / f"{ args .name } _{ current_epoch } .pth" )
294- torch .save (model .state_dict (), Path (args .output_dir ) / f"{ args .name } .pth" )
297+ checkpoint = {
298+ "model" : model .state_dict (),
299+ "optimizer" : optimizer .state_dict (),
300+ "scheduler" : scheduler .state_dict (),
301+ "epoch" : current_epoch ,
302+ "args" : args ,
303+ }
304+ torch .save (checkpoint , Path (args .output_dir ) / f"{ args .name } _{ current_epoch } .pth" )
305+ torch .save (checkpoint , Path (args .output_dir ) / f"{ args .name } .pth" )
295306
296307 if current_epoch % args .val_freq == 0 or done :
297308 evaluate (model , args )
0 commit comments