Skip to content

Commit 83a09cd

Browse files
committed
Fix evaluation running before checkpoint resume, and save/load the model state without the DDP wrapper
1 parent 3a4d43a commit 83a09cd

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

references/optical_flow/train.py

Lines changed: 15 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -130,7 +130,8 @@ def inner_loop(blob):
130130
for i in range(num_processed_samples, len(val_dataset)):
131131
inner_loop(val_dataset[i])
132132

133-
logger.synchronize_between_processes()
133+
logger.synchronize_between_processes()
134+
134135
print(header, logger)
135136

136137

@@ -215,18 +216,13 @@ def main(args):
215216
else:
216217
model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained)
217218

218-
model.to(device)
219-
220219
if args.distributed:
221220
model = model.to(args.local_rank)
222221
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
223-
224-
if args.train_dataset is None:
225-
# Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
226-
torch.backends.cudnn.benchmark = False
227-
torch.backends.cudnn.deterministic = True
228-
evaluate(model, args)
229-
return
222+
model_without_ddp = model.module
223+
else:
224+
model.to(device)
225+
model_without_ddp = model
230226

231227
print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
232228

@@ -246,13 +242,20 @@ def main(args):
246242

247243
if args.resume is not None:
248244
checkpoint = torch.load(args.resume, map_location="cpu")
249-
model.load_state_dict(checkpoint["model"])
245+
model_without_ddp.load_state_dict(checkpoint["model"])
250246
optimizer.load_state_dict(checkpoint["optimizer"])
251247
scheduler.load_state_dict(checkpoint["scheduler"])
252248
args.start_epoch = checkpoint["epoch"] + 1
253249
else:
254250
args.start_epoch = 0
255251

252+
if args.train_dataset is None:
253+
# Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
254+
torch.backends.cudnn.benchmark = False
255+
torch.backends.cudnn.deterministic = True
256+
evaluate(model, args)
257+
return
258+
256259
torch.backends.cudnn.benchmark = True
257260

258261
model.train()
@@ -295,7 +298,7 @@ def main(args):
295298

296299
if not args.distributed or args.rank == 0:
297300
checkpoint = {
298-
"model": model.state_dict(),
301+
"model": model_without_ddp.state_dict(),
299302
"optimizer": optimizer.state_dict(),
300303
"scheduler": scheduler.state_dict(),
301304
"epoch": current_epoch,

0 commit comments

Comments
 (0)