Add training reference for optical flow models #5027

Merged
merged 9 commits on Dec 7, 2021
334 changes: 334 additions & 0 deletions references/optical_flow/train.py
@@ -0,0 +1,334 @@
import argparse
import warnings
from pathlib import Path

import torch
import utils
from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval
from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K
from torchvision.models.optical_flow import raft_large, raft_small


def get_train_dataset(stage, dataset_root):
if stage == "chairs":
transforms = OpticalFlowPresetTrain(crop_size=(368, 496), min_scale=0.1, max_scale=1.0, do_flip=True)
return FlyingChairs(root=dataset_root, split="train", transforms=transforms)
elif stage == "things":
transforms = OpticalFlowPresetTrain(crop_size=(400, 720), min_scale=-0.4, max_scale=0.8, do_flip=True)
return FlyingThings3D(root=dataset_root, split="train", pass_name="both", transforms=transforms)
elif stage == "sintel_SKH":  # S + K + H, as in the RAFT paper
crop_size = (368, 768)
transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.2, max_scale=0.6, do_flip=True)

things_clean = FlyingThings3D(root=dataset_root, split="train", pass_name="clean", transforms=transforms)
sintel = Sintel(root=dataset_root, split="train", pass_name="both", transforms=transforms)

kitti_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.3, max_scale=0.5, do_flip=True)
kitti = KittiFlow(root=dataset_root, split="train", transforms=kitti_transforms)

hd1k_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.5, max_scale=0.2, do_flip=True)
hd1k = HD1K(root=dataset_root, split="train", transforms=hd1k_transforms)

# As a future improvement, we could probably use a distributed sampler here
# The distribution is S(.71), T(.135), K(.135), H(.02)
return 100 * sintel + 200 * kitti + 5 * hd1k + things_clean
Comment on lines +32 to +34
Member

Ok with me. So you added support for __mul__ in those datasets?
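
For context, a minimal sketch of how such repetition could work, assuming a __rmul__ overload that wraps repeated copies in a ConcatDataset; this is illustrative and not necessarily how the torchvision flow datasets implement it. Since Dataset.__add__ already returns a ConcatDataset, an expression like 100 * sintel + 200 * kitti then composes naturally.

import torch.utils.data


class RepeatableDataset(torch.utils.data.Dataset):
    """Illustrative sketch only: a dataset that supports n * dataset repetition."""

    def __init__(self, samples):
        self._samples = samples

    def __len__(self):
        return len(self._samples)

    def __getitem__(self, idx):
        return self._samples[idx]

    def __rmul__(self, v):
        # 100 * dataset dispatches to this reflected operator and yields
        # a ConcatDataset containing v back-to-back copies of the dataset.
        return torch.utils.data.ConcatDataset([self] * v)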

elif stage == "kitti":
transforms = OpticalFlowPresetTrain(
# resize and crop params
crop_size=(288, 960),
min_scale=-0.2,
max_scale=0.4,
stretch_prob=0,
# flip params
do_flip=False,
# jitter params
brightness=0.3,
contrast=0.3,
saturation=0.3,
hue=0.3 / 3.14,
asymmetric_jitter_prob=0,
)
return KittiFlow(root=dataset_root, split="train", transforms=transforms)
else:
raise ValueError(f"Unknown stage {stage}")


@torch.no_grad()
def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
"""Helper function to compute various metrics (epe, etc.) for a model on a given dataset.

We process as many samples as possible with DDP, and process the rest on a single worker.
"""
batch_size = batch_size or args.batch_size

model.eval()

sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
val_loader = torch.utils.data.DataLoader(
val_dataset,
sampler=sampler,
batch_size=batch_size,
pin_memory=True,
num_workers=args.num_workers,
)

num_flow_updates = num_flow_updates or args.num_flow_updates

def inner_loop(blob):
if blob[0].dim() == 3:
# input is not batched so we add an extra dim for consistency
blob = [x[None, :, :, :] if x is not None else None for x in blob]

image1, image2, flow_gt = blob[:3]
valid_flow_mask = None if len(blob) == 3 else blob[-1]

image1, image2 = image1.cuda(), image2.cuda()

padder = utils.InputPadder(image1.shape, mode=padder_mode)
image1, image2 = padder.pad(image1, image2)

flow_predictions = model(image1, image2, num_flow_updates=num_flow_updates)
flow_pred = flow_predictions[-1]
flow_pred = padder.unpad(flow_pred).cpu()

metrics, num_pixels_tot = utils.compute_metrics(flow_pred, flow_gt, valid_flow_mask)

# We compute per-pixel epe (epe) and per-image epe (called f1-epe in RAFT paper).
# per-pixel epe: average epe of all pixels of all images
# per-image epe: average epe on each image independently, then average over images
for name in ("epe", "1px", "3px", "5px", "f1"): # f1 is called f1-all in paper
logger.meters[name].update(metrics[name], n=num_pixels_tot)
logger.meters["per_image_epe"].update(metrics["epe"], n=batch_size)

logger = utils.MetricLogger()
for meter_name in ("epe", "1px", "3px", "5px", "per_image_epe", "f1"):
logger.add_meter(meter_name, fmt="{global_avg:.4f}")

num_processed_samples = 0
for blob in logger.log_every(val_loader, header=header, print_freq=None):
inner_loop(blob)
num_processed_samples += blob[0].shape[0] # batch size

num_processed_samples = utils.reduce_across_processes(num_processed_samples)
print(
f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
"Going to process the remaining samples individually, if any."
)

if args.rank == 0: # we only need to process the rest on a single worker
for i in range(num_processed_samples, len(val_dataset)):
inner_loop(val_dataset[i])

logger.synchronize_between_processes()
print(header, logger)


def validate(model, args):
val_datasets = args.val_dataset or []
for name in val_datasets:
if name == "kitti":
# Kitti has different image sizes, so we need to pad them individually; we can't batch them.
# see comment in InputPadder
if args.batch_size != 1 and args.rank == 0:
warnings.warn(
f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
)

val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=OpticalFlowPresetEval())
_validate(
model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
)
elif name == "sintel":
for pass_name in ("clean", "final"):
val_dataset = Sintel(
root=args.dataset_root, split="train", pass_name=pass_name, transforms=OpticalFlowPresetEval()
)
_validate(
model,
args,
val_dataset,
num_flow_updates=32,
padder_mode="sintel",
header=f"Sintel val {pass_name}",
)
else:
warnings.warn(f"Can't validate on {val_dataset}, skipping.")


def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args):
for data_blob in logger.log_every(train_loader):

optimizer.zero_grad()

image1, image2, flow_gt, valid_flow_mask = (x.cuda() for x in data_blob)
flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)

loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma)
metrics, _ = utils.compute_metrics(flow_predictions[-1], flow_gt, valid_flow_mask)

metrics.pop("f1")
logger.update(loss=loss, **metrics)

loss.backward()

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

optimizer.step()
scheduler.step()

current_step += 1

if current_step == args.num_steps:
return True, current_step

return False, current_step


def main(args):
utils.setup_ddp(args)

model = raft_small() if args.small else raft_large()
model = model.to(args.local_rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])

if args.resume is not None:
d = torch.load(args.resume, map_location="cpu")
model.load_state_dict(d, strict=True)

if args.train_dataset is None:
# Set deterministic cuDNN algorithms, since they can affect epe a fair bit.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
validate(model, args)
return

print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

torch.backends.cudnn.benchmark = True

model.train()
if args.freeze_batch_norm:
utils.freeze_batch_norm(model.module)

train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)

sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
train_loader = torch.utils.data.DataLoader(
train_dataset,
sampler=sampler,
batch_size=args.batch_size,
pin_memory=True,
num_workers=args.num_workers,
)

optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)

scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer=optimizer,
max_lr=args.lr,
total_steps=args.num_steps + 100,
pct_start=0.05,
cycle_momentum=False,
anneal_strategy="linear",
)

logger = utils.MetricLogger()

done = False
current_epoch = current_step = 0
while not done:
print(f"EPOCH {current_epoch}")

sampler.set_epoch(current_epoch) # needed, otherwise the data loading order would be the same for all epochs
done, current_step = train_one_epoch(
model=model,
optimizer=optimizer,
scheduler=scheduler,
train_loader=train_loader,
logger=logger,
current_step=current_step,
args=args,
)

# Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
print(f"Epoch {current_epoch} done. ", logger)

current_epoch += 1

if args.rank == 0:
# TODO: Also save the optimizer and scheduler
torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth")
Comment on lines +260 to +261
Member

Should we also save the optimizer and the scheduler so that we can resume training? This is what we do in the other reference scripts.

Contributor

We should definitely do this. I think it's worth refactoring the script to have the same functionality and structure as the other reference scripts. Moreover, we will need to link the ref scripts with the model prototype and add the --weights feature switch.

@NicolasHug do you mind creating an issue for all the above so that we don't forget?

Member Author

Opened #5056
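
For reference, a minimal sketch of what resumable checkpointing could look like here; the checkpoint dictionary layout is an assumption for illustration, and the actual refactor is tracked in the issue above.

# Hedged sketch only: bundle model, optimizer and scheduler state so training
# can be resumed; the dict keys below are illustrative, not a final format.
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
    "step": current_step,
    "epoch": current_epoch,
}
torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")

# On --resume, restore all three state dicts instead of only the model weights:
checkpoint = torch.load(args.resume, map_location="cpu")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
current_step = checkpoint["step"]
current_epoch = checkpoint["epoch"]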


if current_epoch % args.val_freq == 0 or done:
validate(model, args)
model.train()
if args.freeze_batch_norm:
utils.freeze_batch_norm(model.module)


def get_args_parser(add_help=True):
parser = argparse.ArgumentParser(add_help=add_help, description="Train or evaluate an optical-flow model.")
parser.add_argument(
"--name",
default="raft",
type=str,
help="The name of the experiment - determines the name of the files where weights are saved.",
)
parser.add_argument(
"--output-dir", default="checkpoints", type=str, help="Output dir where checkpoints will be stored."
)
parser.add_argument(
"--resume",
type=str,
help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.",
)

parser.add_argument("--num-workers", type=int, default=12, help="Number of workers for the data loading part.")

parser.add_argument(
"--train-dataset",
type=str,
help="The dataset to use for training. If not passed, only validation is performed (and you probably want to pass --resume).",
)
parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs")
# TODO: eventually, it might be preferable to support epochs instead of num_steps.
# Keeping it this way for now to reproduce results more easily.
parser.add_argument("--num-steps", type=int, default=100000, help="The total number of steps (updates) to train.")
parser.add_argument("--batch-size", type=int, default=6)

parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer")
parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer")
parser.add_argument("--adamw-eps", type=float, default=1e-8, help="eps value for AdamW optimizer")

parser.add_argument(
"--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
)

parser.add_argument("--small", action="store_true", help="Use the 'small' RAFT architecture.")

parser.add_argument(
"--num_flow_updates",
type=int,
default=12,
help="number of updates (or 'iters') in the update operator of the model.",
)

parser.add_argument("--gamma", type=float, default=0.8, help="exponential weighting for loss. Must be < 1.")

parser.add_argument("--dist-url", default="env://", help="URL used to set up distributed training")

parser.add_argument(
"--dataset-root",
help="Root folder where the datasets are stored. Will be passed as the 'root' parameter of the datasets.",
required=True,
)

return parser


if __name__ == "__main__":
args = get_args_parser().parse_args()
Path(args.output_dir).mkdir(exist_ok=True)
main(args)