Skip to content

Commit cf9a5aa

Browse files
Vincent Moens, facebook-github-bot
Vincent Moens
authored and committed
[fbsync] Add training reference for optical flow models (#5027)
Reviewed By: NicolasHug Differential Revision: D32950938 fbshipit-source-id: 0f271d45026c821c109493d9aa7f404b5373012d
1 parent e310ab8 commit cf9a5aa

File tree

2 files changed

+616
-0
lines changed

2 files changed

+616
-0
lines changed

references/optical_flow/train.py

Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
import argparse
2+
import warnings
3+
from pathlib import Path
4+
5+
import torch
6+
import utils
7+
from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval
8+
from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K
9+
from torchvision.models.optical_flow import raft_large, raft_small
10+
11+
12+
def get_train_dataset(stage, dataset_root):
    """Build the training dataset that corresponds to one training stage.

    Args:
        stage: Name of the stage. Supported values are ``"chairs"``,
            ``"things"``, ``"sintel_SKH"`` and ``"kitti"``.
        dataset_root: Forwarded as the ``root`` argument of every dataset
            constructed here.

    Returns:
        The dataset for the requested stage. For ``"sintel_SKH"`` the result is
        the combination ``100 * sintel + 200 * kitti + 5 * hd1k + things_clean``.

    Raises:
        ValueError: If ``stage`` is not one of the supported names.
    """
    if stage not in ("chairs", "things", "sintel_SKH", "kitti"):
        raise ValueError(f"Unknown stage {stage}")

    if stage == "chairs":
        chairs_preset = OpticalFlowPresetTrain(crop_size=(368, 496), min_scale=0.1, max_scale=1.0, do_flip=True)
        return FlyingChairs(root=dataset_root, split="train", transforms=chairs_preset)

    if stage == "things":
        things_preset = OpticalFlowPresetTrain(crop_size=(400, 720), min_scale=-0.4, max_scale=0.8, do_flip=True)
        return FlyingThings3D(root=dataset_root, split="train", pass_name="both", transforms=things_preset)

    if stage == "kitti":
        kitti_preset = OpticalFlowPresetTrain(
            # resize and crop params
            crop_size=(288, 960),
            min_scale=-0.2,
            max_scale=0.4,
            stretch_prob=0,
            # flip params
            do_flip=False,
            # jitter params
            brightness=0.3,
            contrast=0.3,
            saturation=0.3,
            hue=0.3 / 3.14,
            asymmetric_jitter_prob=0,
        )
        return KittiFlow(root=dataset_root, split="train", transforms=kitti_preset)

    # Only "sintel_SKH" is left: S + K + H as from the paper (plus FlyingThings3D clean).
    # Every sub-dataset shares the crop size; only the scale range differs.
    shared_crop = (368, 768)

    base_preset = OpticalFlowPresetTrain(crop_size=shared_crop, min_scale=-0.2, max_scale=0.6, do_flip=True)
    things_clean = FlyingThings3D(root=dataset_root, split="train", pass_name="clean", transforms=base_preset)
    sintel = Sintel(root=dataset_root, split="train", pass_name="both", transforms=base_preset)

    kitti_preset = OpticalFlowPresetTrain(crop_size=shared_crop, min_scale=-0.3, max_scale=0.5, do_flip=True)
    kitti = KittiFlow(root=dataset_root, split="train", transforms=kitti_preset)

    hd1k_preset = OpticalFlowPresetTrain(crop_size=shared_crop, min_scale=-0.5, max_scale=0.2, do_flip=True)
    hd1k = HD1K(root=dataset_root, split="train", transforms=hd1k_preset)

    # As future improvement, we could probably be using a distributed sampler here.
    # The distribution is S(.71), T(.135), K(.135), H(.02)
    sintel_repeated = 100 * sintel
    kitti_repeated = 200 * kitti
    hd1k_repeated = 5 * hd1k
    return sintel_repeated + kitti_repeated + hd1k_repeated + things_clean
54+
55+
56+
@torch.no_grad()
def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
    """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.

    We process as many samples as possible with ddp, and process the rest on a single worker.

    Args:
        model: The model to evaluate. It is switched to eval mode here and is
            NOT switched back; callers that keep training must call
            ``model.train()`` themselves.
        args: Parsed command-line arguments. ``batch_size``, ``num_flow_updates``,
            ``num_workers`` and ``rank`` are read from it.
        val_dataset: Dataset to evaluate on. Indexing it must yield
            ``(image1, image2, flow_gt)`` optionally followed by a valid-flow mask.
        padder_mode: Forwarded as ``mode`` to ``utils.InputPadder``.
        num_flow_updates: Overrides ``args.num_flow_updates`` when truthy.
        batch_size: Overrides ``args.batch_size`` when truthy.
        header: Label passed to the logger and printed with the final metrics.
    """
    batch_size = batch_size or args.batch_size
    num_flow_updates = num_flow_updates or args.num_flow_updates

    model.eval()

    sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=sampler,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=args.num_workers,
    )

    logger = utils.MetricLogger()
    for meter_name in ("epe", "1px", "3px", "5px", "per_image_epe", "f1"):
        logger.add_meter(meter_name, fmt="{global_avg:.4f}")

    def inner_loop(blob):
        if blob[0].dim() == 3:
            # input is not batched so we add an extra dim for consistency
            blob = [x[None, :, :, :] if x is not None else None for x in blob]

        image1, image2, flow_gt = blob[:3]
        valid_flow_mask = None if len(blob) == 3 else blob[-1]

        # Number of images actually present in this blob. It is smaller than
        # `batch_size` for the last (partial) batch of the loader and for the
        # leftover samples that are processed one at a time below.
        num_images = image1.shape[0]

        image1, image2 = image1.cuda(), image2.cuda()

        padder = utils.InputPadder(image1.shape, mode=padder_mode)
        image1, image2 = padder.pad(image1, image2)

        flow_predictions = model(image1, image2, num_flow_updates=num_flow_updates)
        flow_pred = flow_predictions[-1]
        flow_pred = padder.unpad(flow_pred).cpu()

        metrics, num_pixels_tot = utils.compute_metrics(flow_pred, flow_gt, valid_flow_mask)

        # We compute per-pixel epe (epe) and per-image epe (called f1-epe in RAFT paper).
        # per-pixel epe: average epe of all pixels of all images
        # per-image epe: average epe on each image independently, then average over images
        for name in ("epe", "1px", "3px", "5px", "f1"):  # f1 is called f1-all in paper
            logger.meters[name].update(metrics[name], n=num_pixels_tot)
        # Weight by the real image count rather than the nominal batch size, so
        # partial batches and single leftover samples are not over-represented.
        logger.meters["per_image_epe"].update(metrics["epe"], n=num_images)

    num_processed_samples = 0
    for blob in logger.log_every(val_loader, header=header, print_freq=None):
        inner_loop(blob)
        num_processed_samples += blob[0].shape[0]  # batch size

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    print(
        f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
        "Going to process the remaining samples individually, if any."
    )

    if args.rank == 0:  # we only need to process the rest on a single worker
        for i in range(num_processed_samples, len(val_dataset)):
            inner_loop(val_dataset[i])

    logger.synchronize_between_processes()
    print(header, logger)
124+
125+
126+
def validate(model, args):
    """Evaluate ``model`` on every dataset listed in ``args.val_dataset``.

    Supported names are ``"kitti"`` and ``"sintel"`` (the latter is evaluated on
    both the ``clean`` and ``final`` passes). Any other name triggers a warning
    and is skipped. Nothing happens when ``args.val_dataset`` is empty or None.

    Args:
        model: The model to evaluate; forwarded to ``_validate``.
        args: Parsed command-line arguments. ``val_dataset``, ``batch_size``,
            ``rank`` and ``dataset_root`` are read here.
    """
    val_datasets = args.val_dataset or []
    for name in val_datasets:
        if name == "kitti":
            # Kitti has different image sizes so we need to individually pad them, we can't batch.
            # see comment in InputPadder
            if args.batch_size != 1 and args.rank == 0:
                warnings.warn(
                    f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
                )

            val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=OpticalFlowPresetEval())
            _validate(
                model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
            )
        elif name == "sintel":
            for pass_name in ("clean", "final"):
                val_dataset = Sintel(
                    root=args.dataset_root, split="train", pass_name=pass_name, transforms=OpticalFlowPresetEval()
                )
                _validate(
                    model,
                    args,
                    val_dataset,
                    num_flow_updates=32,
                    padder_mode="sintel",
                    header=f"Sintel val {pass_name}",
                )
        else:
            # Report the requested *name*: `val_dataset` is either unbound here
            # (unknown name comes first) or refers to a previously built dataset.
            warnings.warn(f"Can't validate on {name}, skipping.")
156+
157+
158+
def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args):
    """Run optimisation steps over ``train_loader`` until it is exhausted or the step budget is hit.

    Args:
        model: Model called as ``model(image1, image2, num_flow_updates=...)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch, right after the optimizer.
        train_loader: Iterable of 4-element batches
            ``(image1, image2, flow_gt, valid_flow_mask)``.
        logger: Object providing ``log_every`` and ``update``.
        current_step: Number of steps already performed before this call.
        args: Parsed arguments; ``num_flow_updates``, ``gamma`` and ``num_steps`` are read.

    Returns:
        A ``(done, current_step)`` tuple. ``done`` is True when the step counter
        reached ``args.num_steps`` during this call, False when the loader ran out first.
    """
    steps_done = current_step

    for batch in logger.log_every(train_loader):
        optimizer.zero_grad()

        # Move the whole batch to the GPU before unpacking it.
        image1, image2, flow_gt, valid_flow_mask = (tensor.cuda() for tensor in batch)

        predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)
        final_prediction = predictions[-1]

        loss = utils.sequence_loss(predictions, flow_gt, valid_flow_mask, args.gamma)
        metrics, _ = utils.compute_metrics(final_prediction, flow_gt, valid_flow_mask)

        # "f1" is dropped from the metrics that get logged during training.
        metrics.pop("f1")
        logger.update(loss=loss, **metrics)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

        optimizer.step()
        scheduler.step()

        steps_done += 1
        if steps_done == args.num_steps:
            return True, steps_done

    return False, steps_done
185+
186+
187+
def main(args):
    """Entry point: set up DDP, build the model, then either evaluate only or train.

    When ``args.train_dataset`` is None the model is only validated (after
    optionally loading ``args.resume``). Otherwise the model is trained until
    ``train_one_epoch`` reports that the step budget was reached, with a
    checkpoint written by rank 0 after every epoch and validation every
    ``args.val_freq`` epochs as well as after the final epoch.

    Args:
        args: Parsed command-line arguments (see ``get_args_parser``).
    """
    utils.setup_ddp(args)

    build_model = raft_small if args.small else raft_large
    model = build_model().to(args.local_rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])

    if args.resume is not None:
        state = torch.load(args.resume, map_location="cpu")
        model.load_state_dict(state, strict=True)

    if args.train_dataset is None:
        # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        validate(model, args)
        return

    num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Parameter Count: {num_trainable}")

    torch.backends.cudnn.benchmark = True

    def enter_train_mode():
        # Shared by the initial setup and by the post-validation reset below.
        model.train()
        if args.freeze_batch_norm:
            utils.freeze_batch_norm(model.module)

    enter_train_mode()

    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)

    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.num_workers,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=args.lr,
        total_steps=args.num_steps + 100,
        pct_start=0.05,
        cycle_momentum=False,
        anneal_strategy="linear",
    )

    logger = utils.MetricLogger()

    current_epoch = 0
    current_step = 0
    done = False
    while not done:
        print(f"EPOCH {current_epoch}")

        # needed, otherwise the data loading order would be the same for all epochs
        sampler.set_epoch(current_epoch)
        done, current_step = train_one_epoch(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            train_loader=train_loader,
            logger=logger,
            current_step=current_step,
            args=args,
        )

        # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
        print(f"Epoch {current_epoch} done. ", logger)

        current_epoch += 1

        if args.rank == 0:
            # TODO: Also save the optimizer and scheduler
            # One per-epoch snapshot, plus a file that always holds the latest weights.
            for checkpoint_name in (f"{args.name}_{current_epoch}.pth", f"{args.name}.pth"):
                torch.save(model.state_dict(), Path(args.output_dir) / checkpoint_name)

        if current_epoch % args.val_freq == 0 or done:
            validate(model, args)
            # Validation puts the model in eval mode; restore the training configuration.
            enter_train_mode()
268+
269+
270+
def get_args_parser(add_help=True):
    """Build the command-line parser for training / evaluating an optical-flow model.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser``; pass False when this
            parser is meant to be used as a parent of another parser.

    Returns:
        The configured ``argparse.ArgumentParser``. ``--dataset-root`` is the
        only required option.
    """
    parser = argparse.ArgumentParser(add_help=add_help, description="Train or evaluate an optical-flow model.")
    parser.add_argument(
        "--name",
        default="raft",
        type=str,
        help="The name of the experiment - determines the name of the files where weights are saved.",
    )
    parser.add_argument(
        "--output-dir", default="checkpoints", type=str, help="Output dir where checkpoints will be stored."
    )
    parser.add_argument(
        "--resume",
        type=str,
        help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.",
    )

    parser.add_argument("--num-workers", type=int, default=12, help="Number of workers for the data loading part.")

    parser.add_argument(
        "--train-dataset",
        type=str,
        help="The dataset to use for training. If not passed, only validation is performed (and you probably want to pass --resume).",
    )
    parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
    parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs")
    # TODO: eventually, it might be preferable to support epochs instead of num_steps.
    # Keeping it this way for now to reproduce results more easily.
    parser.add_argument("--num-steps", type=int, default=100000, help="The total number of steps (updates) to train.")
    parser.add_argument("--batch-size", type=int, default=6)

    parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer")
    parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer")
    parser.add_argument("--adamw-eps", type=float, default=1e-8, help="eps value for AdamW optimizer")

    parser.add_argument(
        "--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
    )

    parser.add_argument("--small", action="store_true", help="Use the 'small' RAFT architecture.")

    # The hyphenated spelling matches every other flag of this parser; the
    # original underscore spelling is kept as an alias so existing command
    # lines keep working. Both store into `args.num_flow_updates`.
    parser.add_argument(
        "--num-flow-updates",
        "--num_flow_updates",
        type=int,
        default=12,
        help="number of updates (or 'iters') in the update operator of the model.",
    )

    parser.add_argument("--gamma", type=float, default=0.8, help="exponential weighting for loss. Must be < 1.")

    parser.add_argument("--dist-url", default="env://", help="URL used to set up distributed training")

    parser.add_argument(
        "--dataset-root",
        help="Root folder where the datasets are stored. Will be passed as the 'root' parameter of the datasets.",
        required=True,
    )

    return parser
329+
330+
331+
if __name__ == "__main__":
    args = get_args_parser().parse_args()
    # parents=True lets a nested --output-dir (e.g. "runs/exp1/checkpoints") be
    # created in one go; exist_ok=True tolerates a directory that already exists.
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)

0 commit comments

Comments
 (0)