Skip to content

Commit 2857e21

Browse files
committed
Fix case where --train-dataset is None
1 parent 83a09cd commit 2857e21

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

references/optical_flow/train.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,17 @@ def main(args):
224224
model.to(device)
225225
model_without_ddp = model
226226

227+
if args.resume is not None:
228+
checkpoint = torch.load(args.resume, map_location="cpu")
229+
model_without_ddp.load_state_dict(checkpoint["model"])
230+
231+
if args.train_dataset is None:
232+
# Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
233+
torch.backends.cudnn.benchmark = False
234+
torch.backends.cudnn.deterministic = True
235+
evaluate(model, args)
236+
return
237+
227238
print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
228239

229240
train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
@@ -241,21 +252,12 @@ def main(args):
241252
)
242253

243254
if args.resume is not None:
244-
checkpoint = torch.load(args.resume, map_location="cpu")
245-
model_without_ddp.load_state_dict(checkpoint["model"])
246255
optimizer.load_state_dict(checkpoint["optimizer"])
247256
scheduler.load_state_dict(checkpoint["scheduler"])
248257
args.start_epoch = checkpoint["epoch"] + 1
249258
else:
250259
args.start_epoch = 0
251260

252-
if args.train_dataset is None:
253-
# Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
254-
torch.backends.cudnn.benchmark = False
255-
torch.backends.cudnn.deterministic = True
256-
evaluate(model, args)
257-
return
258-
259261
torch.backends.cudnn.benchmark = True
260262

261263
model.train()

0 commit comments

Comments
 (0)