diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py
new file mode 100644
index 00000000000..eaf03fbe4f3
--- /dev/null
+++ b/references/optical_flow/train.py
@@ -0,0 +1,334 @@
+import argparse
+import warnings
+from pathlib import Path
+
+import torch
+import utils
+from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval
+from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K
+from torchvision.models.optical_flow import raft_large, raft_small
+
+
+def get_train_dataset(stage, dataset_root):
+    if stage == "chairs":
+        transforms = OpticalFlowPresetTrain(crop_size=(368, 496), min_scale=0.1, max_scale=1.0, do_flip=True)
+        return FlyingChairs(root=dataset_root, split="train", transforms=transforms)
+    elif stage == "things":
+        transforms = OpticalFlowPresetTrain(crop_size=(400, 720), min_scale=-0.4, max_scale=0.8, do_flip=True)
+        return FlyingThings3D(root=dataset_root, split="train", pass_name="both", transforms=transforms)
+    elif stage == "sintel_SKH":  # S + K + H as in the paper
+        crop_size = (368, 768)
+        transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.2, max_scale=0.6, do_flip=True)
+
+        things_clean = FlyingThings3D(root=dataset_root, split="train", pass_name="clean", transforms=transforms)
+        sintel = Sintel(root=dataset_root, split="train", pass_name="both", transforms=transforms)
+
+        kitti_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.3, max_scale=0.5, do_flip=True)
+        kitti = KittiFlow(root=dataset_root, split="train", transforms=kitti_transforms)
+
+        hd1k_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.5, max_scale=0.2, do_flip=True)
+        hd1k = HD1K(root=dataset_root, split="train", transforms=hd1k_transforms)
+
+        # As a future improvement, we could probably use a distributed sampler here.
+        # The distribution is S(.71), T(.135), K(.135), H(.02)
+        return 100 * sintel + 200 * kitti + 5 * hd1k + things_clean
+    elif stage == "kitti":
+        transforms = OpticalFlowPresetTrain(
+            # resize and crop params
+            crop_size=(288, 960),
+            min_scale=-0.2,
+            max_scale=0.4,
+            stretch_prob=0,
+            # flip params
+            do_flip=False,
+            # jitter params
+            brightness=0.3,
+            contrast=0.3,
+            saturation=0.3,
+            hue=0.3 / 3.14,
+            asymmetric_jitter_prob=0,
+        )
+        return KittiFlow(root=dataset_root, split="train", transforms=transforms)
+    else:
+        raise ValueError(f"Unknown stage {stage}")
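The `100 * sintel + 200 * kitti + 5 * hd1k + things_clean` return value relies on the dataset classes supporting repetition and concatenation arithmetic. A minimal sketch of the same mixing using only stock `torch.utils.data` primitives, where `repeat_dataset` is a hypothetical helper (only `+`, i.e. `ConcatDataset`, is guaranteed on a plain `Dataset`):

```python
import torch.utils.data


def repeat_dataset(dataset, n):
    # n concatenated copies of the same dataset: each epoch visits it n times
    return torch.utils.data.ConcatDataset([dataset] * n)


# Roughly the S(.71), T(.135), K(.135), H(.02) mix from the comment above:
# mix = repeat_dataset(sintel, 100) + repeat_dataset(kitti, 200) + repeat_dataset(hd1k, 5) + things_clean
```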
+ """ + batch_size = batch_size or args.batch_size + + model.eval() + + sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True) + val_loader = torch.utils.data.DataLoader( + val_dataset, + sampler=sampler, + batch_size=batch_size, + pin_memory=True, + num_workers=args.num_workers, + ) + + num_flow_updates = num_flow_updates or args.num_flow_updates + + def inner_loop(blob): + if blob[0].dim() == 3: + # input is not batched so we add an extra dim for consistency + blob = [x[None, :, :, :] if x is not None else None for x in blob] + + image1, image2, flow_gt = blob[:3] + valid_flow_mask = None if len(blob) == 3 else blob[-1] + + image1, image2 = image1.cuda(), image2.cuda() + + padder = utils.InputPadder(image1.shape, mode=padder_mode) + image1, image2 = padder.pad(image1, image2) + + flow_predictions = model(image1, image2, num_flow_updates=num_flow_updates) + flow_pred = flow_predictions[-1] + flow_pred = padder.unpad(flow_pred).cpu() + + metrics, num_pixels_tot = utils.compute_metrics(flow_pred, flow_gt, valid_flow_mask) + + # We compute per-pixel epe (epe) and per-image epe (called f1-epe in RAFT paper). + # per-pixel epe: average epe of all pixels of all images + # per-image epe: average epe on each image independently, then average over images + for name in ("epe", "1px", "3px", "5px", "f1"): # f1 is called f1-all in paper + logger.meters[name].update(metrics[name], n=num_pixels_tot) + logger.meters["per_image_epe"].update(metrics["epe"], n=batch_size) + + logger = utils.MetricLogger() + for meter_name in ("epe", "1px", "3px", "5px", "per_image_epe", "f1"): + logger.add_meter(meter_name, fmt="{global_avg:.4f}") + + num_processed_samples = 0 + for blob in logger.log_every(val_loader, header=header, print_freq=None): + inner_loop(blob) + num_processed_samples += blob[0].shape[0] # batch size + + num_processed_samples = utils.reduce_across_processes(num_processed_samples) + print( + f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. " + "Going to process the remaining samples individually, if any." + ) + + if args.rank == 0: # we only need to process the rest on a single worker + for i in range(num_processed_samples, len(val_dataset)): + inner_loop(val_dataset[i]) + + logger.synchronize_between_processes() + print(header, logger) + + +def validate(model, args): + val_datasets = args.val_dataset or [] + for name in val_datasets: + if name == "kitti": + # Kitti has different image sizes so we need to individually pad them, we can't batch. + # see comment in InputPadder + if args.batch_size != 1 and args.rank == 0: + warnings.warn( + f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1." 
+
+
+def validate(model, args):
+    val_datasets = args.val_dataset or []
+    for name in val_datasets:
+        if name == "kitti":
+            # Kitti has different image sizes, so we need to pad them individually; we can't batch.
+            # See comment in InputPadder.
+            if args.batch_size != 1 and args.rank == 0:
+                warnings.warn(
+                    f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
+                )
+
+            val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=OpticalFlowPresetEval())
+            _validate(
+                model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
+            )
+        elif name == "sintel":
+            for pass_name in ("clean", "final"):
+                val_dataset = Sintel(
+                    root=args.dataset_root, split="train", pass_name=pass_name, transforms=OpticalFlowPresetEval()
+                )
+                _validate(
+                    model,
+                    args,
+                    val_dataset,
+                    num_flow_updates=32,
+                    padder_mode="sintel",
+                    header=f"Sintel val {pass_name}",
+                )
+        else:
+            warnings.warn(f"Can't validate on {name}, skipping.")
+
+
+def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args):
+    for data_blob in logger.log_every(train_loader):
+
+        optimizer.zero_grad()
+
+        image1, image2, flow_gt, valid_flow_mask = (x.cuda() for x in data_blob)
+        flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)
+
+        loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma)
+        metrics, _ = utils.compute_metrics(flow_predictions[-1], flow_gt, valid_flow_mask)
+
+        metrics.pop("f1")
+        logger.update(loss=loss, **metrics)
+
+        loss.backward()
+
+        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
+
+        optimizer.step()
+        scheduler.step()
+
+        current_step += 1
+
+        if current_step == args.num_steps:
+            return True, current_step
+
+    return False, current_step
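For reference, `utils.sequence_loss` (defined in utils.py below) weights the `args.num_flow_updates` intermediate predictions exponentially, so that later refinements contribute more to the loss. With the defaults, the weights look like this:

```python
import torch

gamma, num_predictions = 0.8, 12
weights = gamma ** torch.arange(num_predictions - 1, -1, -1)
# weights[0] ≈ 0.0859 for the first (coarsest) prediction,
# weights[-1] = 1.0 for the final refined flow
```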
", logger) + + current_epoch += 1 + + if args.rank == 0: + # TODO: Also save the optimizer and scheduler + torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth") + torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth") + + if current_epoch % args.val_freq == 0 or done: + validate(model, args) + model.train() + if args.freeze_batch_norm: + utils.freeze_batch_norm(model.module) + + +def get_args_parser(add_help=True): + parser = argparse.ArgumentParser(add_help=add_help, description="Train or evaluate an optical-flow model.") + parser.add_argument( + "--name", + default="raft", + type=str, + help="The name of the experiment - determines the name of the files where weights are saved.", + ) + parser.add_argument( + "--output-dir", default="checkpoints", type=str, help="Output dir where checkpoints will be stored." + ) + parser.add_argument( + "--resume", + type=str, + help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.", + ) + + parser.add_argument("--num-workers", type=int, default=12, help="Number of workers for the data loading part.") + + parser.add_argument( + "--train-dataset", + type=str, + help="The dataset to use for training. If not passed, only validation is performed (and you probably want to pass --resume).", + ) + parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.") + parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs") + # TODO: eventually, it might be preferable to support epochs instead of num_steps. + # Keeping it this way for now to reproduce results more easily. + parser.add_argument("--num-steps", type=int, default=100000, help="The total number of steps (updates) to train.") + parser.add_argument("--batch-size", type=int, default=6) + + parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer") + parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer") + parser.add_argument("--adamw-eps", type=float, default=1e-8, help="eps value for AdamW optimizer") + + parser.add_argument( + "--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode." + ) + + parser.add_argument("--small", action="store_true", help="Use the 'small' RAFT architecture.") + + parser.add_argument( + "--num_flow_updates", + type=int, + default=12, + help="number of updates (or 'iters') in the update operator of the model.", + ) + + parser.add_argument("--gamma", type=float, default=0.8, help="exponential weighting for loss. Must be < 1.") + + parser.add_argument("--dist-url", default="env://", help="URL used to set up distributed training") + + parser.add_argument( + "--dataset-root", + help="Root folder where the datasets are stored. 
+
+
+def get_args_parser(add_help=True):
+    parser = argparse.ArgumentParser(add_help=add_help, description="Train or evaluate an optical-flow model.")
+    parser.add_argument(
+        "--name",
+        default="raft",
+        type=str,
+        help="The name of the experiment; determines the name of the files where weights are saved.",
+    )
+    parser.add_argument(
+        "--output-dir", default="checkpoints", type=str, help="Output dir where checkpoints will be stored."
+    )
+    parser.add_argument(
+        "--resume",
+        type=str,
+        help="A path to previously saved weights. Used to restart training from, or to evaluate a pre-saved model.",
+    )
+
+    parser.add_argument("--num-workers", type=int, default=12, help="Number of workers for the data loading part.")
+
+    parser.add_argument(
+        "--train-dataset",
+        type=str,
+        help="The dataset to use for training. If not passed, only validation is performed (and you probably want to pass --resume).",
+    )
+    parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
+    parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs.")
+    # TODO: eventually, it might be preferable to support epochs instead of num_steps.
+    # Keeping it this way for now to reproduce results more easily.
+    parser.add_argument("--num-steps", type=int, default=100000, help="The total number of steps (updates) to train.")
+    parser.add_argument("--batch-size", type=int, default=6)
+
+    parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for the AdamW optimizer.")
+    parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for the AdamW optimizer.")
+    parser.add_argument("--adamw-eps", type=float, default=1e-8, help="eps value for the AdamW optimizer.")
+
+    parser.add_argument(
+        "--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
+    )
+
+    parser.add_argument("--small", action="store_true", help="Use the 'small' RAFT architecture.")
+
+    parser.add_argument(
+        "--num-flow-updates",
+        type=int,
+        default=12,
+        help="Number of updates (or 'iters') in the update operator of the model.",
+    )
+
+    parser.add_argument("--gamma", type=float, default=0.8, help="Exponential weighting for the loss. Must be < 1.")
+
+    parser.add_argument("--dist-url", default="env://", help="URL used to set up distributed training.")
+
+    parser.add_argument(
+        "--dataset-root",
+        help="Root folder where the datasets are stored. Will be passed as the 'root' parameter of the datasets.",
+        required=True,
+    )
+
+    return parser
+
+
+if __name__ == "__main__":
+    args = get_args_parser().parse_args()
+    Path(args.output_dir).mkdir(exist_ok=True)
+    main(args)
diff --git a/references/optical_flow/utils.py b/references/optical_flow/utils.py
new file mode 100644
index 00000000000..e3643a91663
--- /dev/null
+++ b/references/optical_flow/utils.py
@@ -0,0 +1,282 @@
+import datetime
+import os
+import time
+from collections import defaultdict
+from collections import deque
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+
+class SmoothedValue:
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
+
+    def __init__(self, window_size=20, fmt="{median:.4f} ({global_avg:.4f})"):
+        self.deque = deque(maxlen=window_size)
+        self.total = 0.0
+        self.count = 0
+        self.fmt = fmt
+
+    def update(self, value, n=1):
+        self.deque.append(value)
+        self.count += n
+        self.total += value * n
+
+    def synchronize_between_processes(self):
+        """
+        Warning: does not synchronize the deque!
+        """
+        t = reduce_across_processes([self.count, self.total])
+        t = t.tolist()
+        self.count = int(t[0])
+        self.total = t[1]
+
+    @property
+    def median(self):
+        d = torch.tensor(list(self.deque))
+        return d.median().item()
+
+    @property
+    def avg(self):
+        d = torch.tensor(list(self.deque), dtype=torch.float32)
+        return d.mean().item()
+
+    @property
+    def global_avg(self):
+        return self.total / self.count
+
+    @property
+    def max(self):
+        return max(self.deque)
+
+    @property
+    def value(self):
+        return self.deque[-1]
+
+    def __str__(self):
+        return self.fmt.format(
+            median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
+        )
+
+
+class MetricLogger:
+    def __init__(self, delimiter="\t"):
+        self.meters = defaultdict(SmoothedValue)
+        self.delimiter = delimiter
+
+    def update(self, **kwargs):
+        for k, v in kwargs.items():
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            assert isinstance(v, (float, int))
+            self.meters[k].update(v)
+
+    def __getattr__(self, attr):
+        if attr in self.meters:
+            return self.meters[attr]
+        if attr in self.__dict__:
+            return self.__dict__[attr]
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
+
+    def __str__(self):
+        loss_str = []
+        for name, meter in self.meters.items():
+            loss_str.append(f"{name}: {str(meter)}")
+        return self.delimiter.join(loss_str)
+
+    def synchronize_between_processes(self):
+        for meter in self.meters.values():
+            meter.synchronize_between_processes()
+
+    def add_meter(self, name, **kwargs):
+        self.meters[name] = SmoothedValue(**kwargs)
+
+    def log_every(self, iterable, print_freq=5, header=None):
+        i = 0
+        if not header:
+            header = ""
+        start_time = time.time()
+        end = time.time()
+        iter_time = SmoothedValue(fmt="{avg:.4f}")
+        data_time = SmoothedValue(fmt="{avg:.4f}")
+        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
+        if torch.cuda.is_available():
+            log_msg = self.delimiter.join(
+                [
+                    header,
+                    "[{0" + space_fmt + "}/{1}]",
+                    "eta: {eta}",
+                    "{meters}",
+                    "time: {time}",
+                    "data: {data}",
+                    "max mem: {memory:.0f}",
+                ]
+            )
+        else:
+            log_msg = self.delimiter.join(
+                [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
+            )
+        MB = 1024.0 * 1024.0
+        for obj in iterable:
+            data_time.update(time.time() - end)
+            yield obj
+            iter_time.update(time.time() - end)
+            if print_freq is not None and i % print_freq == 0:
+                eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                if torch.cuda.is_available():
+                    print(
+                        log_msg.format(
+                            i,
+                            len(iterable),
+                            eta=eta_string,
+                            meters=str(self),
+                            time=str(iter_time),
+                            data=str(data_time),
+                            memory=torch.cuda.max_memory_allocated() / MB,
+                        )
+                    )
+                else:
+                    print(
+                        log_msg.format(
+                            i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
+                        )
+                    )
+            i += 1
+            end = time.time()
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        print(f"{header} Total time: {total_time_str}")
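Toy usage of the two averages a `SmoothedValue` tracks (single process, so no synchronization involved):

```python
v = SmoothedValue(window_size=3, fmt="{avg:.2f} ({global_avg:.2f})")
for x in (1.0, 2.0, 3.0, 4.0):
    v.update(x)

print(v)  # "3.00 (2.50)": mean over the last 3 values vs. global mean
print(v.max)  # 4.0 -- max (like median, avg and value) only sees the window
```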
+
+
+def compute_metrics(flow_pred, flow_gt, valid_flow_mask=None):
+
+    epe = ((flow_pred - flow_gt) ** 2).sum(dim=1).sqrt()
+    flow_norm = (flow_gt ** 2).sum(dim=1).sqrt()
+
+    if valid_flow_mask is not None:
+        epe = epe[valid_flow_mask]
+        flow_norm = flow_norm[valid_flow_mask]
+
+    relative_epe = epe / flow_norm
+
+    metrics = {
+        "epe": epe.mean().item(),
+        "1px": (epe < 1).float().mean().item(),
+        "3px": (epe < 3).float().mean().item(),
+        "5px": (epe < 5).float().mean().item(),
+        "f1": ((epe > 3) & (relative_epe > 0.05)).float().mean().item() * 100,
+    }
+    return metrics, epe.numel()
+
+
+def sequence_loss(flow_preds, flow_gt, valid_flow_mask, gamma=0.8, max_flow=400):
+    """Loss function defined over a sequence of flow predictions."""
+
+    if gamma > 1:
+        raise ValueError(f"Gamma should be < 1, got {gamma}.")
+
+    # exclude invalid pixels and extremely large displacements
+    flow_norm = torch.sum(flow_gt ** 2, dim=1).sqrt()
+    valid_flow_mask = valid_flow_mask & (flow_norm < max_flow)
+
+    valid_flow_mask = valid_flow_mask[:, None, :, :]
+
+    flow_preds = torch.stack(flow_preds)  # shape = (num_flow_updates, batch_size, 2, H, W)
+
+    abs_diff = (flow_preds - flow_gt).abs()
+    abs_diff = (abs_diff * valid_flow_mask).mean(dim=(1, 2, 3, 4))
+
+    num_predictions = flow_preds.shape[0]
+    weights = gamma ** torch.arange(num_predictions - 1, -1, -1).to(flow_gt.device)
+    flow_loss = (abs_diff * weights).sum()
+
+    return flow_loss
+
+
+class InputPadder:
+    """Pads images such that their dimensions are divisible by 8."""
+
+    # TODO: Ideally, this should be part of the eval transforms preset, instead
+    # of being part of the validation code. It's not obvious what a good
+    # solution would be, because we need to unpad the predicted flows according
+    # to the input images' size, and in some datasets (Kitti) images can have
+    # variable sizes.
+
+    def __init__(self, dims, mode="sintel"):
+        self.ht, self.wd = dims[-2:]
+        pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
+        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
+        if mode == "sintel":
+            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
+        else:
+            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
+
+    def pad(self, *inputs):
+        return [F.pad(x, self._pad, mode="replicate") for x in inputs]
+
+    def unpad(self, x):
+        ht, wd = x.shape[-2:]
+        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
+        return x[..., c[0] : c[1], c[2] : c[3]]
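Working through the `__init__` arithmetic with a typical Kitti image size (this is just the formula from the class, evaluated by hand):

```python
ht, wd = 370, 1226  # a common Kitti resolution
pad_ht = (((ht // 8) + 1) * 8 - ht) % 8  # 6 -> height padded to 376
pad_wd = (((wd // 8) + 1) * 8 - wd) % 8  # 6 -> width padded to 1232

# "sintel" mode splits the padding evenly (3 top / 3 bottom here);
# "kitti" mode still splits the width, but pads the height only at the
# bottom, keeping the image anchored to the top edge.
```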
+
+
+def _redefine_print(is_main):
+    """Disables printing when not in the main process."""
+    import builtins as __builtin__
+
+    builtin_print = __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop("force", False)
+        if is_main or force:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+
+def setup_ddp(args):
+    # Set the local_rank, rank, and world_size values as args fields.
+    # This is done differently depending on how we're running the script. We
+    # currently support either torchrun or the custom run_with_submitit.py.
+    # If you're confused (like I was), this might help a bit:
+    # https://discuss.pytorch.org/t/what-is-the-difference-between-rank-and-local-rank/61940/2
+
+    if all(key in os.environ for key in ("LOCAL_RANK", "RANK", "WORLD_SIZE")):
+        # if we're here, the script was called with torchrun. Otherwise,
+        # these args will be set already by the run_with_submitit script
+        args.local_rank = int(os.environ["LOCAL_RANK"])
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ["WORLD_SIZE"])
+
+    elif "gpu" in args:
+        # if we're here, the script was called by run_with_submitit.py
+        args.local_rank = args.gpu
+    else:
+        raise ValueError(r"Sorry, I can't set up the distributed training ¯\_(ツ)_/¯.")
+
+    _redefine_print(is_main=(args.rank == 0))
+
+    torch.cuda.set_device(args.local_rank)
+    dist.init_process_group(
+        backend="nccl",
+        rank=args.rank,
+        world_size=args.world_size,
+        init_method=args.dist_url,
+    )
+
+
+def reduce_across_processes(val):
+    t = torch.tensor(val, device="cuda")
+    dist.barrier()
+    dist.all_reduce(t)
+    return t
+
+
+def freeze_batch_norm(model):
+    for m in model.modules():
+        if isinstance(m, torch.nn.BatchNorm2d):
+            m.eval()
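`setup_ddp` expects `torchrun` (or `run_with_submitit.py`) to provide the rank-related environment variables or args. Below is a minimal single-process sanity check of the `reduce_across_processes` pattern, assuming the `gloo` backend and a CPU tensor so it runs without GPUs (the script itself uses `nccl` with CUDA tensors):

```python
import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

t = torch.tensor([3.0, 10.0])
dist.barrier()
dist.all_reduce(t)  # sums across ranks; a no-op with world_size=1
print(t)  # tensor([ 3., 10.])
dist.destroy_process_group()
```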