@@ -130,7 +130,8 @@ def inner_loop(blob):
     for i in range(num_processed_samples, len(val_dataset)):
         inner_loop(val_dataset[i])

-    logger.synchronize_between_processes()
+    logger.synchronize_between_processes()
+
     print(header, logger)

@@ -215,18 +216,13 @@ def main(args):
     else:
         model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained)

-    model.to(device)
-
     if args.distributed:
         model = model.to(args.local_rank)
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
-
-    if args.train_dataset is None:
-        # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
-        torch.backends.cudnn.benchmark = False
-        torch.backends.cudnn.deterministic = True
-        evaluate(model, args)
-        return
+        model_without_ddp = model.module
+    else:
+        model.to(device)
+        model_without_ddp = model

     print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

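For context: `DistributedDataParallel` wraps the model and prefixes every `state_dict()` key with `module.`, so this hunk and the ones below route checkpoint loading and saving through the unwrapped handle. A minimal, self-contained sketch of the pattern, not the repo's code (the `torch.nn.Linear` stand-in and the checkpoint path are hypothetical):

```python
import torch

# Hypothetical stand-in for the optical-flow model built above.
model = torch.nn.Linear(4, 2)

if torch.distributed.is_available() and torch.distributed.is_initialized():
    # DDP wraps the model; its state_dict() keys gain a "module." prefix.
    model = torch.nn.parallel.DistributedDataParallel(model)
    model_without_ddp = model.module  # handle to the raw, unwrapped model
else:
    model_without_ddp = model

# Saving and loading through the unwrapped handle keeps checkpoint keys
# plain ("weight", not "module.weight"), so a checkpoint round-trips
# between single-process and DDP runs.
torch.save({"model": model_without_ddp.state_dict()}, "checkpoint.pth")
state = torch.load("checkpoint.pth", map_location="cpu")
model_without_ddp.load_state_dict(state["model"])
```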
@@ -246,13 +242,20 @@ def main(args):

     if args.resume is not None:
         checkpoint = torch.load(args.resume, map_location="cpu")
-        model.load_state_dict(checkpoint["model"])
+        model_without_ddp.load_state_dict(checkpoint["model"])
         optimizer.load_state_dict(checkpoint["optimizer"])
         scheduler.load_state_dict(checkpoint["scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
     else:
         args.start_epoch = 0

+    if args.train_dataset is None:
+        # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+        evaluate(model, args)
+        return
+
     torch.backends.cudnn.benchmark = True

     model.train()
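Note on the hunk above: moving the evaluation-only early return below the `--resume` block means resumed weights are now loaded before `evaluate` runs, while the deterministic-cuDNN settings still apply only to that path (the training path re-enables `benchmark` right after). A small sketch of the two modes, assuming a CUDA build of PyTorch:

```python
import torch

# Evaluation path: fix cuDNN's convolution algorithm choice so epe is
# reproducible; benchmark mode autotunes per input shape, and the chosen
# algorithm can shift results slightly.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Training path (the line after the early return): re-enable autotuning
# for speed, since exact reproducibility matters less there.
torch.backends.cudnn.benchmark = True
```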
@@ -295,7 +298,7 @@ def main(args):

         if not args.distributed or args.rank == 0:
             checkpoint = {
-                "model": model.state_dict(),
+                "model": model_without_ddp.state_dict(),
                 "optimizer": optimizer.state_dict(),
                 "scheduler": scheduler.state_dict(),
                 "epoch": current_epoch,