@@ -224,6 +224,17 @@ def main(args):
     model.to(device)
     model_without_ddp = model
 
+    if args.resume is not None:
+        checkpoint = torch.load(args.resume, map_location="cpu")
+        model_without_ddp.load_state_dict(checkpoint["model"])
+
+    if args.train_dataset is None:
+        # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+        evaluate(model, args)
+        return
+
     print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
     train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
@@ -241,21 +252,12 @@ def main(args):
     )
 
     if args.resume is not None:
-        checkpoint = torch.load(args.resume, map_location="cpu")
-        model_without_ddp.load_state_dict(checkpoint["model"])
         optimizer.load_state_dict(checkpoint["optimizer"])
         scheduler.load_state_dict(checkpoint["scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
     else:
         args.start_epoch = 0
 
-    if args.train_dataset is None:
-        # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
-        torch.backends.cudnn.benchmark = False
-        torch.backends.cudnn.deterministic = True
-        evaluate(model, args)
-        return
-
     torch.backends.cudnn.benchmark = True
 
     model.train()