Commit 09d78d1

Enable saving the optimizer and scheduler on the checkpoint
1 parent 2276b22 commit 09d78d1


references/optical_flow/train.py

Lines changed: 33 additions & 22 deletions
@@ -221,10 +221,6 @@ def main(args):
     model = model.to(args.local_rank)
     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])

-    if args.resume is not None:
-        d = torch.load(args.resume, map_location="cpu")
-        model.load_state_dict(d, strict=True)
-
     if args.train_dataset is None:
         # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
         torch.backends.cudnn.benchmark = False
@@ -234,14 +230,35 @@ def main(args):

     print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

+    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
+
+    scheduler = torch.optim.lr_scheduler.OneCycleLR(
+        optimizer=optimizer,
+        max_lr=args.lr,
+        epochs=args.epochs,
+        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
+        pct_start=0.05,
+        cycle_momentum=False,
+        anneal_strategy="linear",
+    )
+
+    if args.resume is not None:
+        checkpoint = torch.load(args.resume, map_location="cpu")
+        model.load_state_dict(checkpoint["model"])
+        optimizer.load_state_dict(checkpoint["optimizer"])
+        scheduler.load_state_dict(checkpoint["scheduler"])
+        args.start_epoch = checkpoint["epoch"] + 1
+    else:
+        args.start_epoch = 0
+
     torch.backends.cudnn.benchmark = True

     model.train()
     if args.freeze_batch_norm:
         utils.freeze_batch_norm(model.module)

-    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
-
     if args.distributed:
         sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
     else:
@@ -255,22 +272,10 @@ def main(args):
         num_workers=args.num_workers,
     )

-    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
-
-    scheduler = torch.optim.lr_scheduler.OneCycleLR(
-        optimizer=optimizer,
-        max_lr=args.lr,
-        epochs=args.epochs,
-        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
-        pct_start=0.05,
-        cycle_momentum=False,
-        anneal_strategy="linear",
-    )
-
     logger = utils.MetricLogger()

     done = False
-    for current_epoch in range(args.epochs):
+    for current_epoch in range(args.start_epoch, args.epochs):
         print(f"EPOCH {current_epoch}")
         if args.distributed:
             # needed on distributed mode, otherwise the data loading order would be the same for all epochs
@@ -289,9 +294,15 @@ def main(args):
         print(f"Epoch {current_epoch} done. ", logger)

         if not args.distributed or args.rank == 0:
-            # TODO: Also save the optimizer and scheduler
-            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
-            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth")
+            checkpoint = {
+                "model": model.state_dict(),
+                "optimizer": optimizer.state_dict(),
+                "scheduler": scheduler.state_dict(),
+                "epoch": current_epoch,
+                "args": args,
+            }
+            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
+            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")

         if current_epoch % args.val_freq == 0 or done:
             evaluate(model, args)
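
For reference, a checkpoint written by the updated script is a plain Python dict with the keys shown in the diff ("model", "optimizer", "scheduler", "epoch", "args"). The sketch below is a minimal, hypothetical example of inspecting such a file outside of train.py: the path is made up, and the remark about the "module." prefix assumes the weights were saved from the DistributedDataParallel wrapper, as in the script above.

import torch

# Hypothetical path to a checkpoint produced by the updated train.py.
ckpt_path = "./checkpoints/raft_0.pth"

checkpoint = torch.load(ckpt_path, map_location="cpu")
print(checkpoint.keys())    # dict_keys(['model', 'optimizer', 'scheduler', 'epoch', 'args'])
print(checkpoint["epoch"])  # epoch at which the checkpoint was written

# Resuming inside train.py rebuilds the model/optimizer/scheduler first and then
# restores their states, mirroring the diff:
#   model.load_state_dict(checkpoint["model"])
#   optimizer.load_state_dict(checkpoint["optimizer"])
#   scheduler.load_state_dict(checkpoint["scheduler"])
#   args.start_epoch = checkpoint["epoch"] + 1

# For standalone inference with an unwrapped model, the "module." prefix that
# DistributedDataParallel adds to the weight keys would need to be stripped first:
weights = {k[len("module."):] if k.startswith("module.") else k: v
           for k, v in checkpoint["model"].items()}

The "args" entry holds the run's configuration object (presumably the parsed command-line arguments), so the saved file also records the hyperparameters the run was launched with.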
