@@ -224,6 +224,17 @@ def main(args):
     model.to(device)
     model_without_ddp = model
 
+    if args.resume is not None:
+        checkpoint = torch.load(args.resume, map_location="cpu")
+        model_without_ddp.load_state_dict(checkpoint["model"])
+
+    if args.train_dataset is None:
+        # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+        evaluate(model, args)
+        return
+
     print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
     train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
@@ -241,21 +252,12 @@ def main(args):
     )
 
     if args.resume is not None:
-        checkpoint = torch.load(args.resume, map_location="cpu")
-        model_without_ddp.load_state_dict(checkpoint["model"])
         optimizer.load_state_dict(checkpoint["optimizer"])
         scheduler.load_state_dict(checkpoint["scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
     else:
         args.start_epoch = 0
 
-    if args.train_dataset is None:
-        # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
-        torch.backends.cudnn.benchmark = False
-        torch.backends.cudnn.deterministic = True
-        evaluate(model, args)
-        return
-
     torch.backends.cudnn.benchmark = True
 
     model.train()
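For context, the control flow after this move looks roughly like the sketch below: checkpoint weights are restored before the evaluation-only early return, so a run without a train dataset never has to build the training dataset, optimizer, or scheduler. This is a minimal, self-contained sketch, not the actual file; make_model and the argparse flags are hypothetical stand-ins, and only the ordering of the checkpoint load, the early return, and the cuDNN flags follows the diff above.

# Simplified sketch of the resulting main() ordering; helper functions
# below are hypothetical stand-ins, not the real reference-script helpers.
import argparse
import torch

def make_model(args):
    return torch.nn.Linear(4, 2)  # stand-in for the real optical-flow model

def evaluate(model, args):
    print("evaluation-only run")

def main(args):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = make_model(args)
    model.to(device)
    model_without_ddp = model

    # Model weights are loaded first, so both the evaluation-only path
    # and the training path below see the restored parameters.
    if args.resume is not None:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])

    if args.train_dataset is None:
        # Deterministic cuDNN algorithms keep the reported epe reproducible.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        evaluate(model, args)
        return

    # Training path only: datasets, optimizer, and scheduler are built from
    # here on, with optimizer/scheduler state restored from the same checkpoint.
    torch.backends.cudnn.benchmark = True
    model.train()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume", default=None)
    parser.add_argument("--train-dataset", default=None)
    main(parser.parse_args())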