
Commit 3aa2a93

RAFT training reference Improvement (#5590)
* Change optical flow train.py function name from validate to evaluate so it is similar to other references
* Add --device as a parameter and enable running in non-distributed mode
* Format with ufmt
* Fix unnecessary param and bug
* Enable saving the optimizer and scheduler in the checkpoint
* Fix bug when evaluating before resume, and save or load the model without DDP
* Fix the case where --train-dataset is None

Co-authored-by: Nicolas Hug <[email protected]>
1 parent 7be2f55 commit 3aa2a93
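
The change that enables CPU and single-process runs is mechanical but worth seeing in isolation: the scripts build a torch.device from the new --device flag and move data with .to(device) instead of the CUDA-only .cuda(). A minimal sketch of that pattern (the args_device value and tensor shape are made up for illustration):

import torch

args_device = "cpu"                  # stand-in for the parsed --device argument
device = torch.device(args_device)   # works for "cuda" and "cpu" alike

image1 = torch.rand(1, 3, 368, 496)  # made-up input; real batches come from the datasets
image1 = image1.to(device)           # replaces the previous image1.cuda()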

2 files changed: +85 additions, -44 deletions


references/optical_flow/train.py

Lines changed: 79 additions & 43 deletions
@@ -60,16 +60,21 @@ def get_train_dataset(stage, dataset_root):


 @torch.no_grad()
-def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
+def _evaluate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
     """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.

     We process as many samples as possible with ddp, and process the rest on a single worker.
     """
     batch_size = batch_size or args.batch_size
+    device = torch.device(args.device)

     model.eval()

-    sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
+    if args.distributed:
+        sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
+    else:
+        sampler = torch.utils.data.SequentialSampler(val_dataset)
+
     val_loader = torch.utils.data.DataLoader(
         val_dataset,
         sampler=sampler,
@@ -88,7 +93,7 @@ def inner_loop(blob):
         image1, image2, flow_gt = blob[:3]
         valid_flow_mask = None if len(blob) == 3 else blob[-1]

-        image1, image2 = image1.cuda(), image2.cuda()
+        image1, image2 = image1.to(device), image2.to(device)

         padder = utils.InputPadder(image1.shape, mode=padder_mode)
         image1, image2 = padder.pad(image1, image2)
@@ -115,21 +120,22 @@ def inner_loop(blob):
         inner_loop(blob)
         num_processed_samples += blob[0].shape[0]  # batch size

-    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
-    print(
-        f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
-        "Going to process the remaining samples individually, if any."
-    )
+    if args.distributed:
+        num_processed_samples = utils.reduce_across_processes(num_processed_samples)
+        print(
+            f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
+            "Going to process the remaining samples individually, if any."
+        )
+        if args.rank == 0:  # we only need to process the rest on a single worker
+            for i in range(num_processed_samples, len(val_dataset)):
+                inner_loop(val_dataset[i])

-    if args.rank == 0:  # we only need to process the rest on a single worker
-        for i in range(num_processed_samples, len(val_dataset)):
-            inner_loop(val_dataset[i])
+        logger.synchronize_between_processes()

-    logger.synchronize_between_processes()
     print(header, logger)


-def validate(model, args):
+def evaluate(model, args):
     val_datasets = args.val_dataset or []

     if args.prototype:
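
As the docstring above notes, the batched pass covers as many samples as DDP allows and rank 0 then handles the leftovers one by one. A rough illustration of where that remainder comes from, with made-up numbers (DistributedSampler with drop_last=True trims the dataset so every process gets the same number of samples):

dataset_len, world_size = 1041, 4           # hypothetical values
per_rank = dataset_len // world_size        # 260 samples handed to each process
batch_processed = per_rank * world_size     # 1040 samples counted across all ranks
remainder = dataset_len - batch_processed   # 1 sample left for rank 0's inner_loop
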
@@ -145,21 +151,21 @@ def validate(model, args):
         if name == "kitti":
             # Kitti has different image sizes so we need to individually pad them, we can't batch.
             # see comment in InputPadder
-            if args.batch_size != 1 and args.rank == 0:
+            if args.batch_size != 1 and (not args.distributed or args.rank == 0):
                 warnings.warn(
                     f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
                 )

             val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
-            _validate(
+            _evaluate(
                 model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
             )
         elif name == "sintel":
             for pass_name in ("clean", "final"):
                 val_dataset = Sintel(
                     root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
                 )
-                _validate(
+                _evaluate(
                     model,
                     args,
                     val_dataset,
@@ -172,11 +178,12 @@ def validate(model, args):


 def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
+    device = torch.device(args.device)
     for data_blob in logger.log_every(train_loader):

         optimizer.zero_grad()

-        image1, image2, flow_gt, valid_flow_mask = (x.cuda() for x in data_blob)
+        image1, image2, flow_gt, valid_flow_mask = (x.to(device) for x in data_blob)
         flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)

         loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma)
@@ -200,36 +207,68 @@ def main(args):
         raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
     utils.setup_ddp(args)

+    if args.distributed and args.device == "cpu":
+        raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
+    device = torch.device(args.device)
+
     if args.prototype:
         model = prototype.models.optical_flow.__dict__[args.model](weights=args.weights)
     else:
         model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained)

-    model = model.to(args.local_rank)
-    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
+    if args.distributed:
+        model = model.to(args.local_rank)
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
+        model_without_ddp = model.module
+    else:
+        model.to(device)
+        model_without_ddp = model

     if args.resume is not None:
-        d = torch.load(args.resume, map_location="cpu")
-        model.load_state_dict(d, strict=True)
+        checkpoint = torch.load(args.resume, map_location="cpu")
+        model_without_ddp.load_state_dict(checkpoint["model"])

     if args.train_dataset is None:
         # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
         torch.backends.cudnn.benchmark = False
         torch.backends.cudnn.deterministic = True
-        validate(model, args)
+        evaluate(model, args)
         return

     print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

+    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
+
+    scheduler = torch.optim.lr_scheduler.OneCycleLR(
+        optimizer=optimizer,
+        max_lr=args.lr,
+        epochs=args.epochs,
+        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
+        pct_start=0.05,
+        cycle_momentum=False,
+        anneal_strategy="linear",
+    )
+
+    if args.resume is not None:
+        optimizer.load_state_dict(checkpoint["optimizer"])
+        scheduler.load_state_dict(checkpoint["scheduler"])
+        args.start_epoch = checkpoint["epoch"] + 1
+    else:
+        args.start_epoch = 0
+
     torch.backends.cudnn.benchmark = True

     model.train()
     if args.freeze_batch_norm:
         utils.freeze_batch_norm(model.module)

-    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)
+    if args.distributed:
+        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
+    else:
+        sampler = torch.utils.data.RandomSampler(train_dataset)

-    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
     train_loader = torch.utils.data.DataLoader(
         train_dataset,
         sampler=sampler,
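
The optimizer and scheduler are now built before the resume logic so their state can be restored from the checkpoint. The OneCycleLR length still has to match the number of optimizer steps the loop will actually take; a quick check with made-up numbers:

from math import ceil

len_train_dataset, world_size, batch_size, epochs = 22872, 8, 2, 10    # hypothetical
steps_per_epoch = ceil(len_train_dataset / (world_size * batch_size))  # ceil(22872 / 16) = 1430
total_scheduler_steps = steps_per_epoch * epochs                       # 14300 scheduler.step() calls
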
@@ -238,25 +277,15 @@ def main(args):
         num_workers=args.num_workers,
     )

-    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
-
-    scheduler = torch.optim.lr_scheduler.OneCycleLR(
-        optimizer=optimizer,
-        max_lr=args.lr,
-        epochs=args.epochs,
-        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
-        pct_start=0.05,
-        cycle_momentum=False,
-        anneal_strategy="linear",
-    )
-
     logger = utils.MetricLogger()

     done = False
-    for current_epoch in range(args.epochs):
+    for current_epoch in range(args.start_epoch, args.epochs):
         print(f"EPOCH {current_epoch}")
+        if args.distributed:
+            # needed on distributed mode, otherwise the data loading order would be the same for all epochs
+            sampler.set_epoch(current_epoch)

-        sampler.set_epoch(current_epoch)  # needed, otherwise the data loading order would be the same for all epochs
         train_one_epoch(
             model=model,
             optimizer=optimizer,
@@ -269,13 +298,19 @@ def main(args):
         # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
         print(f"Epoch {current_epoch} done. ", logger)

-        if args.rank == 0:
-            # TODO: Also save the optimizer and scheduler
-            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
-            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth")
+        if not args.distributed or args.rank == 0:
+            checkpoint = {
+                "model": model_without_ddp.state_dict(),
+                "optimizer": optimizer.state_dict(),
+                "scheduler": scheduler.state_dict(),
+                "epoch": current_epoch,
+                "args": args,
+            }
+            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
+            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")

         if current_epoch % args.val_freq == 0 or done:
-            validate(model, args)
+            evaluate(model, args)
             model.train()
             if args.freeze_batch_norm:
                 utils.freeze_batch_norm(model.module)
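
Checkpoints are no longer a bare state_dict but a dict carrying the model, optimizer, scheduler, epoch and args, saved from the unwrapped model_without_ddp. A minimal consumer-side sketch, assuming a RAFT-large run and a placeholder path (raft_large is an assumed choice for --model; adjust to whatever was actually trained):

import torch
import torchvision

checkpoint = torch.load("raft_large_checkpoint.pth", map_location="cpu")  # placeholder path
model = torchvision.models.optical_flow.raft_large()
model.load_state_dict(checkpoint["model"])  # plain state dict, no "module." prefixes
next_epoch = checkpoint["epoch"] + 1        # what --resume would use as start_epoch
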
@@ -349,6 +384,7 @@ def get_args_parser(add_help=True):
         action="store_true",
     )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
+    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu, Default: cuda)")

     return parser

references/optical_flow/utils.py

Lines changed: 6 additions & 1 deletion
@@ -256,7 +256,12 @@ def setup_ddp(args):
         # if we're here, the script was called by run_with_submitit.py
         args.local_rank = args.gpu
     else:
-        raise ValueError(r"Sorry, I can't set up the distributed training ¯\_(ツ)_/¯.")
+        print("Not using distributed mode!")
+        args.distributed = False
+        args.world_size = 1
+        return
+
+    args.distributed = True

     _redefine_print(is_main=(args.rank == 0))
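
With setup_ddp now falling back instead of raising, anything that performs collective communication has to be guarded, which is why train.py only calls utils.reduce_across_processes under args.distributed. A hedged sketch of that guard pattern using plain torch.distributed (the reference's own helper may differ in its details):

import torch
import torch.distributed as dist

def safe_reduce_sum(value: int) -> int:
    # No initialized process group -> single-process run: nothing to reduce.
    if not (dist.is_available() and dist.is_initialized()):
        return value
    t = torch.tensor(value, device="cuda" if torch.cuda.is_available() else "cpu")
    dist.all_reduce(t, op=dist.ReduceOp.SUM)  # sums the per-rank counts in place
    return int(t.item())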
