Add training reference for optical flow models #5027
@@ -0,0 +1,334 @@
import argparse
import warnings
from pathlib import Path

import torch
import utils
from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval
from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K
from torchvision.models.optical_flow import raft_large, raft_small


def get_train_dataset(stage, dataset_root):
    if stage == "chairs":
        transforms = OpticalFlowPresetTrain(crop_size=(368, 496), min_scale=0.1, max_scale=1.0, do_flip=True)
        return FlyingChairs(root=dataset_root, split="train", transforms=transforms)
    elif stage == "things":
        transforms = OpticalFlowPresetTrain(crop_size=(400, 720), min_scale=-0.4, max_scale=0.8, do_flip=True)
        return FlyingThings3D(root=dataset_root, split="train", pass_name="both", transforms=transforms)
    elif stage == "sintel_SKH":  # S + K + H as in the paper
        crop_size = (368, 768)
        transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.2, max_scale=0.6, do_flip=True)

        things_clean = FlyingThings3D(root=dataset_root, split="train", pass_name="clean", transforms=transforms)
        sintel = Sintel(root=dataset_root, split="train", pass_name="both", transforms=transforms)

        kitti_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.3, max_scale=0.5, do_flip=True)
        kitti = KittiFlow(root=dataset_root, split="train", transforms=kitti_transforms)

        hd1k_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.5, max_scale=0.2, do_flip=True)
        hd1k = HD1K(root=dataset_root, split="train", transforms=hd1k_transforms)

        # As a future improvement, we could probably use a distributed sampler here.
        # The distribution is S(.71), T(.135), K(.135), H(.02)
        return 100 * sintel + 200 * kitti + 5 * hd1k + things_clean
    elif stage == "kitti":
        transforms = OpticalFlowPresetTrain(
            # resize and crop params
            crop_size=(288, 960),
            min_scale=-0.2,
            max_scale=0.4,
            stretch_prob=0,
            # flip params
            do_flip=False,
            # jitter params
            brightness=0.3,
            contrast=0.3,
            saturation=0.3,
            hue=0.3 / 3.14,
            asymmetric_jitter_prob=0,
        )
        return KittiFlow(root=dataset_root, split="train", transforms=transforms)
    else:
        raise ValueError(f"Unknown stage {stage}")


@torch.no_grad()
def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
    """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.

    We process as many samples as possible with ddp, and process the rest on a single worker.
    """
    batch_size = batch_size or args.batch_size

    model.eval()

    sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=sampler,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=args.num_workers,
    )

    num_flow_updates = num_flow_updates or args.num_flow_updates

    def inner_loop(blob):
        if blob[0].dim() == 3:
            # input is not batched so we add an extra dim for consistency
            blob = [x[None, :, :, :] if x is not None else None for x in blob]

        image1, image2, flow_gt = blob[:3]
        valid_flow_mask = None if len(blob) == 3 else blob[-1]

        image1, image2 = image1.cuda(), image2.cuda()

        padder = utils.InputPadder(image1.shape, mode=padder_mode)
        image1, image2 = padder.pad(image1, image2)

        flow_predictions = model(image1, image2, num_flow_updates=num_flow_updates)
        flow_pred = flow_predictions[-1]
        flow_pred = padder.unpad(flow_pred).cpu()

        metrics, num_pixels_tot = utils.compute_metrics(flow_pred, flow_gt, valid_flow_mask)

        # We compute per-pixel epe (epe) and per-image epe (called f1-epe in RAFT paper).
        # per-pixel epe: average epe of all pixels of all images
        # per-image epe: average epe on each image independently, then average over images
        for name in ("epe", "1px", "3px", "5px", "f1"):  # f1 is called f1-all in paper
            logger.meters[name].update(metrics[name], n=num_pixels_tot)

        logger.meters["per_image_epe"].update(metrics["epe"], n=batch_size)

    logger = utils.MetricLogger()
    for meter_name in ("epe", "1px", "3px", "5px", "per_image_epe", "f1"):
        logger.add_meter(meter_name, fmt="{global_avg:.4f}")

    num_processed_samples = 0
    for blob in logger.log_every(val_loader, header=header, print_freq=None):
        inner_loop(blob)
        num_processed_samples += blob[0].shape[0]  # batch size

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    print(
        f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
        "Going to process the remaining samples individually, if any."
    )

    if args.rank == 0:  # we only need to process the rest on a single worker
        for i in range(num_processed_samples, len(val_dataset)):
            inner_loop(val_dataset[i])

    logger.synchronize_between_processes()
    print(header, logger)


def validate(model, args):
    val_datasets = args.val_dataset or []
    for name in val_datasets:
        if name == "kitti":
            # Kitti has different image sizes so we need to individually pad them, we can't batch.
            # see comment in InputPadder
            if args.batch_size != 1 and args.rank == 0:
                warnings.warn(
                    f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
                )

            val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=OpticalFlowPresetEval())
            _validate(
                model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
            )
        elif name == "sintel":
            for pass_name in ("clean", "final"):
                val_dataset = Sintel(
                    root=args.dataset_root, split="train", pass_name=pass_name, transforms=OpticalFlowPresetEval()
                )
                _validate(
                    model,
                    args,
                    val_dataset,
                    num_flow_updates=32,
                    padder_mode="sintel",
                    header=f"Sintel val {pass_name}",
                )
        else:
warnings.warn(f"Can't validate on {val_dataset}, skipping.") | ||


def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args):
    for data_blob in logger.log_every(train_loader):

        optimizer.zero_grad()

        image1, image2, flow_gt, valid_flow_mask = (x.cuda() for x in data_blob)
        flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)

        loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma)
        metrics, _ = utils.compute_metrics(flow_predictions[-1], flow_gt, valid_flow_mask)

        metrics.pop("f1")
        logger.update(loss=loss, **metrics)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

        optimizer.step()
        scheduler.step()

        current_step += 1

        if current_step == args.num_steps:
            return True, current_step

    return False, current_step


def main(args):
    utils.setup_ddp(args)

    model = raft_small() if args.small else raft_large()
    model = model.to(args.local_rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])

    if args.resume is not None:
        d = torch.load(args.resume, map_location="cpu")
        model.load_state_dict(d, strict=True)

    if args.train_dataset is None:
        # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        validate(model, args)
        return

    print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    torch.backends.cudnn.benchmark = True

    model.train()
    if args.freeze_batch_norm:
        utils.freeze_batch_norm(model.module)

    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)

    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.num_workers,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=args.lr,
        total_steps=args.num_steps + 100,
        pct_start=0.05,
        cycle_momentum=False,
        anneal_strategy="linear",
    )

    logger = utils.MetricLogger()

    done = False
    current_epoch = current_step = 0
    while not done:
        print(f"EPOCH {current_epoch}")

        sampler.set_epoch(current_epoch)  # needed, otherwise the data loading order would be the same for all epochs
        done, current_step = train_one_epoch(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            train_loader=train_loader,
            logger=logger,
            current_step=current_step,
            args=args,
        )

        # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
        print(f"Epoch {current_epoch} done. ", logger)

        current_epoch += 1

        if args.rank == 0:
            # TODO: Also save the optimizer and scheduler
            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
            torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth")
Comment on lines +260 to +261

Should we also save the optimizer and scheduler?

We should definitely do this. I think it's worth refactoring the script to have the same functionality and structure as the other reference scripts. Moreover, we will need to link the ref scripts with the model prototype and add the
@NicolasHug do you mind creating an issue for all of the above so that we don't forget?

Opened #5056
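For reference, a fuller checkpoint along the lines discussed above would bundle the optimizer and scheduler state with the weights. This is only a sketch of the usual PyTorch pattern, not code from this PR; the save_checkpoint helper name is made up:

from pathlib import Path

import torch


def save_checkpoint(model, optimizer, scheduler, current_epoch, current_step, args):
    # Hypothetical helper: persist everything needed to resume training, not just the weights.
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "epoch": current_epoch,
        "step": current_step,
    }
    torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
    torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")

Resuming would then restore each component with its own load_state_dict call instead of loading the raw weights dict directly into the model.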

        if current_epoch % args.val_freq == 0 or done:
            validate(model, args)
            model.train()
            if args.freeze_batch_norm:
                utils.freeze_batch_norm(model.module)


def get_args_parser(add_help=True):
    parser = argparse.ArgumentParser(add_help=add_help, description="Train or evaluate an optical-flow model.")
    parser.add_argument(
        "--name",
        default="raft",
        type=str,
        help="The name of the experiment - determines the name of the files where weights are saved.",
    )
    parser.add_argument(
        "--output-dir", default="checkpoints", type=str, help="Output dir where checkpoints will be stored."
    )
    parser.add_argument(
        "--resume",
        type=str,
        help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.",
    )

    parser.add_argument("--num-workers", type=int, default=12, help="Number of workers for the data loading part.")

    parser.add_argument(
        "--train-dataset",
        type=str,
        help="The dataset to use for training. If not passed, only validation is performed (and you probably want to pass --resume).",
    )
    parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
    parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs")
    # TODO: eventually, it might be preferable to support epochs instead of num_steps.
    # Keeping it this way for now to reproduce results more easily.
    parser.add_argument("--num-steps", type=int, default=100000, help="The total number of steps (updates) to train.")
    parser.add_argument("--batch-size", type=int, default=6)

    parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer")
    parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer")
    parser.add_argument("--adamw-eps", type=float, default=1e-8, help="eps value for AdamW optimizer")

    parser.add_argument(
        "--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
    )

    parser.add_argument("--small", action="store_true", help="Use the 'small' RAFT architecture.")

    parser.add_argument(
        "--num_flow_updates",
        type=int,
        default=12,
        help="number of updates (or 'iters') in the update operator of the model.",
    )

    parser.add_argument("--gamma", type=float, default=0.8, help="exponential weighting for loss. Must be < 1.")

    parser.add_argument("--dist-url", default="env://", help="URL used to set up distributed training")

    parser.add_argument(
        "--dataset-root",
        help="Root folder where the datasets are stored. Will be passed as the 'root' parameter of the datasets.",
        required=True,
    )

    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    Path(args.output_dir).mkdir(exist_ok=True)
    main(args)
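For context on the --gamma flag: utils.sequence_loss computes a RAFT-style sequence loss, an exponentially weighted sum of L1 errors over the model's iterative flow refinements, so later (more refined) predictions receive larger weights. The sketch below shows the standard formulation from the RAFT paper; the actual utils.sequence_loss in this PR may differ in details (for example, it may also exclude pixels whose ground-truth flow magnitude is very large), and the raft_sequence_loss name is made up:

import torch


def raft_sequence_loss(flow_preds, flow_gt, valid_flow_mask=None, gamma=0.8):
    # Sketch of the RAFT sequence loss: sum_i gamma**(N - i - 1) * mean L1 error of prediction i.
    # flow_preds is the list of intermediate predictions returned by the model, most refined last.
    n = len(flow_preds)
    loss = 0.0
    for i, flow_pred in enumerate(flow_preds):
        weight = gamma ** (n - i - 1)  # gamma < 1, so the final prediction gets weight 1
        l1 = (flow_pred - flow_gt).abs().sum(dim=1)  # per-pixel L1 over the 2 flow channels
        if valid_flow_mask is not None:
            l1 = l1[valid_flow_mask]  # only count pixels with valid ground truth
        loss = loss + weight * l1.mean()
    return loss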
Ok with me. So you added support for __mul__ in those datasets?
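Regarding the __mul__ question above: repeating a dataset k times can be implemented by returning a ConcatDataset of k references to it, which is what makes an expression like 100 * sintel + 200 * kitti + 5 * hd1k + things_clean in get_train_dataset work (torch.utils.data.Dataset already defines __add__ to return a ConcatDataset). A minimal sketch, assuming a hypothetical base class rather than the actual torchvision implementation:

from torch.utils.data import ConcatDataset, Dataset


class RepeatableDataset(Dataset):
    # Hypothetical base class illustrating `k * dataset`; the real torchvision
    # flow datasets may implement this differently.
    def __init__(self, samples):
        self._samples = list(samples)

    def __getitem__(self, idx):
        return self._samples[idx]

    def __len__(self):
        return len(self._samples)

    def __mul__(self, k):
        # k references to the same underlying dataset, concatenated back to back
        return ConcatDataset(k * [self])

    __rmul__ = __mul__  # so that `100 * dataset` works as well as `dataset * 100`

Under this reading, the multipliers in get_train_dataset only change how often each dataset's samples appear per epoch; they don't duplicate any data in memory.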