|
| 1 | +import argparse |
| 2 | +import warnings |
| 3 | +from pathlib import Path |
| 4 | + |
| 5 | +import torch |
| 6 | +import utils |
| 7 | +from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval |
| 8 | +from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K |
| 9 | +from torchvision.models.optical_flow import raft_large, raft_small |
| 10 | + |
| 11 | + |
| 12 | +def get_train_dataset(stage, dataset_root): |
| 13 | + if stage == "chairs": |
| 14 | + transforms = OpticalFlowPresetTrain(crop_size=(368, 496), min_scale=0.1, max_scale=1.0, do_flip=True) |
| 15 | + return FlyingChairs(root=dataset_root, split="train", transforms=transforms) |
| 16 | + elif stage == "things": |
| 17 | + transforms = OpticalFlowPresetTrain(crop_size=(400, 720), min_scale=-0.4, max_scale=0.8, do_flip=True) |
| 18 | + return FlyingThings3D(root=dataset_root, split="train", pass_name="both", transforms=transforms) |
| 19 | + elif stage == "sintel_SKH": # S + K + H as from paper |
| 20 | + crop_size = (368, 768) |
| 21 | + transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.2, max_scale=0.6, do_flip=True) |
| 22 | + |
| 23 | + things_clean = FlyingThings3D(root=dataset_root, split="train", pass_name="clean", transforms=transforms) |
| 24 | + sintel = Sintel(root=dataset_root, split="train", pass_name="both", transforms=transforms) |
| 25 | + |
| 26 | + kitti_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.3, max_scale=0.5, do_flip=True) |
| 27 | + kitti = KittiFlow(root=dataset_root, split="train", transforms=kitti_transforms) |
| 28 | + |
| 29 | + hd1k_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.5, max_scale=0.2, do_flip=True) |
| 30 | + hd1k = HD1K(root=dataset_root, split="train", transforms=hd1k_transforms) |
| 31 | + |
| 32 | + # As future improvement, we could probably be using a distributed sampler here |
| 33 | + # The distribution is S(.71), T(.135), K(.135), H(.02) |
| 34 | + return 100 * sintel + 200 * kitti + 5 * hd1k + things_clean |
| 35 | + elif stage == "kitti": |
| 36 | + transforms = OpticalFlowPresetTrain( |
| 37 | + # resize and crop params |
| 38 | + crop_size=(288, 960), |
| 39 | + min_scale=-0.2, |
| 40 | + max_scale=0.4, |
| 41 | + stretch_prob=0, |
| 42 | + # flip params |
| 43 | + do_flip=False, |
| 44 | + # jitter params |
| 45 | + brightness=0.3, |
| 46 | + contrast=0.3, |
| 47 | + saturation=0.3, |
| 48 | + hue=0.3 / 3.14, |
| 49 | + asymmetric_jitter_prob=0, |
| 50 | + ) |
| 51 | + return KittiFlow(root=dataset_root, split="train", transforms=transforms) |
| 52 | + else: |
| 53 | + raise ValueError(f"Unknown stage {stage}") |
| 54 | + |
| 55 | + |
| 56 | +@torch.no_grad() |
| 57 | +def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None): |
| 58 | + """Helper function to compute various metrics (epe, etc.) for a model on a given dataset. |
| 59 | +
|
| 60 | + We process as many samples as possible with ddp, and process the rest on a single worker. |
| 61 | + """ |
| 62 | + batch_size = batch_size or args.batch_size |
| 63 | + |
| 64 | + model.eval() |
| 65 | + |
| 66 | + sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True) |
| 67 | + val_loader = torch.utils.data.DataLoader( |
| 68 | + val_dataset, |
| 69 | + sampler=sampler, |
| 70 | + batch_size=batch_size, |
| 71 | + pin_memory=True, |
| 72 | + num_workers=args.num_workers, |
| 73 | + ) |
| 74 | + |
| 75 | + num_flow_updates = num_flow_updates or args.num_flow_updates |
| 76 | + |
| 77 | + def inner_loop(blob): |
| 78 | + if blob[0].dim() == 3: |
| 79 | + # input is not batched so we add an extra dim for consistency |
| 80 | + blob = [x[None, :, :, :] if x is not None else None for x in blob] |
| 81 | + |
| 82 | + image1, image2, flow_gt = blob[:3] |
| 83 | + valid_flow_mask = None if len(blob) == 3 else blob[-1] |
| 84 | + |
| 85 | + image1, image2 = image1.cuda(), image2.cuda() |
| 86 | + |
| 87 | + padder = utils.InputPadder(image1.shape, mode=padder_mode) |
| 88 | + image1, image2 = padder.pad(image1, image2) |
| 89 | + |
| 90 | + flow_predictions = model(image1, image2, num_flow_updates=num_flow_updates) |
| 91 | + flow_pred = flow_predictions[-1] |
| 92 | + flow_pred = padder.unpad(flow_pred).cpu() |
| 93 | + |
| 94 | + metrics, num_pixels_tot = utils.compute_metrics(flow_pred, flow_gt, valid_flow_mask) |
| 95 | + |
| 96 | + # We compute per-pixel epe (epe) and per-image epe (called f1-epe in RAFT paper). |
| 97 | + # per-pixel epe: average epe of all pixels of all images |
| 98 | + # per-image epe: average epe on each image independently, then average over images |
| 99 | + for name in ("epe", "1px", "3px", "5px", "f1"): # f1 is called f1-all in paper |
| 100 | + logger.meters[name].update(metrics[name], n=num_pixels_tot) |
| 101 | + logger.meters["per_image_epe"].update(metrics["epe"], n=batch_size) |
| 102 | + |
| 103 | + logger = utils.MetricLogger() |
| 104 | + for meter_name in ("epe", "1px", "3px", "5px", "per_image_epe", "f1"): |
| 105 | + logger.add_meter(meter_name, fmt="{global_avg:.4f}") |
| 106 | + |
| 107 | + num_processed_samples = 0 |
| 108 | + for blob in logger.log_every(val_loader, header=header, print_freq=None): |
| 109 | + inner_loop(blob) |
| 110 | + num_processed_samples += blob[0].shape[0] # batch size |
| 111 | + |
| 112 | + num_processed_samples = utils.reduce_across_processes(num_processed_samples) |
| 113 | + print( |
| 114 | + f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. " |
| 115 | + "Going to process the remaining samples individually, if any." |
| 116 | + ) |
| 117 | + |
| 118 | + if args.rank == 0: # we only need to process the rest on a single worker |
| 119 | + for i in range(num_processed_samples, len(val_dataset)): |
| 120 | + inner_loop(val_dataset[i]) |
| 121 | + |
| 122 | + logger.synchronize_between_processes() |
| 123 | + print(header, logger) |
| 124 | + |
| 125 | + |
| 126 | +def validate(model, args): |
| 127 | + val_datasets = args.val_dataset or [] |
| 128 | + for name in val_datasets: |
| 129 | + if name == "kitti": |
| 130 | + # Kitti has different image sizes so we need to individually pad them, we can't batch. |
| 131 | + # see comment in InputPadder |
| 132 | + if args.batch_size != 1 and args.rank == 0: |
| 133 | + warnings.warn( |
| 134 | + f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1." |
| 135 | + ) |
| 136 | + |
| 137 | + val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=OpticalFlowPresetEval()) |
| 138 | + _validate( |
| 139 | + model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1 |
| 140 | + ) |
| 141 | + elif name == "sintel": |
| 142 | + for pass_name in ("clean", "final"): |
| 143 | + val_dataset = Sintel( |
| 144 | + root=args.dataset_root, split="train", pass_name=pass_name, transforms=OpticalFlowPresetEval() |
| 145 | + ) |
| 146 | + _validate( |
| 147 | + model, |
| 148 | + args, |
| 149 | + val_dataset, |
| 150 | + num_flow_updates=32, |
| 151 | + padder_mode="sintel", |
| 152 | + header=f"Sintel val {pass_name}", |
| 153 | + ) |
| 154 | + else: |
| 155 | + warnings.warn(f"Can't validate on {val_dataset}, skipping.") |
| 156 | + |
| 157 | + |
| 158 | +def train_one_epoch(model, optimizer, scheduler, train_loader, logger, current_step, args): |
| 159 | + for data_blob in logger.log_every(train_loader): |
| 160 | + |
| 161 | + optimizer.zero_grad() |
| 162 | + |
| 163 | + image1, image2, flow_gt, valid_flow_mask = (x.cuda() for x in data_blob) |
| 164 | + flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates) |
| 165 | + |
| 166 | + loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma) |
| 167 | + metrics, _ = utils.compute_metrics(flow_predictions[-1], flow_gt, valid_flow_mask) |
| 168 | + |
| 169 | + metrics.pop("f1") |
| 170 | + logger.update(loss=loss, **metrics) |
| 171 | + |
| 172 | + loss.backward() |
| 173 | + |
| 174 | + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) |
| 175 | + |
| 176 | + optimizer.step() |
| 177 | + scheduler.step() |
| 178 | + |
| 179 | + current_step += 1 |
| 180 | + |
| 181 | + if current_step == args.num_steps: |
| 182 | + return True, current_step |
| 183 | + |
| 184 | + return False, current_step |
| 185 | + |
| 186 | + |
| 187 | +def main(args): |
| 188 | + utils.setup_ddp(args) |
| 189 | + |
| 190 | + model = raft_small() if args.small else raft_large() |
| 191 | + model = model.to(args.local_rank) |
| 192 | + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank]) |
| 193 | + |
| 194 | + if args.resume is not None: |
| 195 | + d = torch.load(args.resume, map_location="cpu") |
| 196 | + model.load_state_dict(d, strict=True) |
| 197 | + |
| 198 | + if args.train_dataset is None: |
| 199 | + # Set deterministic CUDNN algorithms, since they can affect epe a fair bit. |
| 200 | + torch.backends.cudnn.benchmark = False |
| 201 | + torch.backends.cudnn.deterministic = True |
| 202 | + validate(model, args) |
| 203 | + return |
| 204 | + |
| 205 | + print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") |
| 206 | + |
| 207 | + torch.backends.cudnn.benchmark = True |
| 208 | + |
| 209 | + model.train() |
| 210 | + if args.freeze_batch_norm: |
| 211 | + utils.freeze_batch_norm(model.module) |
| 212 | + |
| 213 | + train_dataset = get_train_dataset(args.train_dataset, args.dataset_root) |
| 214 | + |
| 215 | + sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True) |
| 216 | + train_loader = torch.utils.data.DataLoader( |
| 217 | + train_dataset, |
| 218 | + sampler=sampler, |
| 219 | + batch_size=args.batch_size, |
| 220 | + pin_memory=True, |
| 221 | + num_workers=args.num_workers, |
| 222 | + ) |
| 223 | + |
| 224 | + optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps) |
| 225 | + |
| 226 | + scheduler = torch.optim.lr_scheduler.OneCycleLR( |
| 227 | + optimizer=optimizer, |
| 228 | + max_lr=args.lr, |
| 229 | + total_steps=args.num_steps + 100, |
| 230 | + pct_start=0.05, |
| 231 | + cycle_momentum=False, |
| 232 | + anneal_strategy="linear", |
| 233 | + ) |
| 234 | + |
| 235 | + logger = utils.MetricLogger() |
| 236 | + |
| 237 | + done = False |
| 238 | + current_epoch = current_step = 0 |
| 239 | + while not done: |
| 240 | + print(f"EPOCH {current_epoch}") |
| 241 | + |
| 242 | + sampler.set_epoch(current_epoch) # needed, otherwise the data loading order would be the same for all epochs |
| 243 | + done, current_step = train_one_epoch( |
| 244 | + model=model, |
| 245 | + optimizer=optimizer, |
| 246 | + scheduler=scheduler, |
| 247 | + train_loader=train_loader, |
| 248 | + logger=logger, |
| 249 | + current_step=current_step, |
| 250 | + args=args, |
| 251 | + ) |
| 252 | + |
| 253 | + # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0 |
| 254 | + print(f"Epoch {current_epoch} done. ", logger) |
| 255 | + |
| 256 | + current_epoch += 1 |
| 257 | + |
| 258 | + if args.rank == 0: |
| 259 | + # TODO: Also save the optimizer and scheduler |
| 260 | + torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth") |
| 261 | + torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth") |
| 262 | + |
| 263 | + if current_epoch % args.val_freq == 0 or done: |
| 264 | + validate(model, args) |
| 265 | + model.train() |
| 266 | + if args.freeze_batch_norm: |
| 267 | + utils.freeze_batch_norm(model.module) |
| 268 | + |
| 269 | + |
| 270 | +def get_args_parser(add_help=True): |
| 271 | + parser = argparse.ArgumentParser(add_help=add_help, description="Train or evaluate an optical-flow model.") |
| 272 | + parser.add_argument( |
| 273 | + "--name", |
| 274 | + default="raft", |
| 275 | + type=str, |
| 276 | + help="The name of the experiment - determines the name of the files where weights are saved.", |
| 277 | + ) |
| 278 | + parser.add_argument( |
| 279 | + "--output-dir", default="checkpoints", type=str, help="Output dir where checkpoints will be stored." |
| 280 | + ) |
| 281 | + parser.add_argument( |
| 282 | + "--resume", |
| 283 | + type=str, |
| 284 | + help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.", |
| 285 | + ) |
| 286 | + |
| 287 | + parser.add_argument("--num-workers", type=int, default=12, help="Number of workers for the data loading part.") |
| 288 | + |
| 289 | + parser.add_argument( |
| 290 | + "--train-dataset", |
| 291 | + type=str, |
| 292 | + help="The dataset to use for training. If not passed, only validation is performed (and you probably want to pass --resume).", |
| 293 | + ) |
| 294 | + parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.") |
| 295 | + parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs") |
| 296 | + # TODO: eventually, it might be preferable to support epochs instead of num_steps. |
| 297 | + # Keeping it this way for now to reproduce results more easily. |
| 298 | + parser.add_argument("--num-steps", type=int, default=100000, help="The total number of steps (updates) to train.") |
| 299 | + parser.add_argument("--batch-size", type=int, default=6) |
| 300 | + |
| 301 | + parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer") |
| 302 | + parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer") |
| 303 | + parser.add_argument("--adamw-eps", type=float, default=1e-8, help="eps value for AdamW optimizer") |
| 304 | + |
| 305 | + parser.add_argument( |
| 306 | + "--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode." |
| 307 | + ) |
| 308 | + |
| 309 | + parser.add_argument("--small", action="store_true", help="Use the 'small' RAFT architecture.") |
| 310 | + |
| 311 | + parser.add_argument( |
| 312 | + "--num_flow_updates", |
| 313 | + type=int, |
| 314 | + default=12, |
| 315 | + help="number of updates (or 'iters') in the update operator of the model.", |
| 316 | + ) |
| 317 | + |
| 318 | + parser.add_argument("--gamma", type=float, default=0.8, help="exponential weighting for loss. Must be < 1.") |
| 319 | + |
| 320 | + parser.add_argument("--dist-url", default="env://", help="URL used to set up distributed training") |
| 321 | + |
| 322 | + parser.add_argument( |
| 323 | + "--dataset-root", |
| 324 | + help="Root folder where the datasets are stored. Will be passed as the 'root' parameter of the datasets.", |
| 325 | + required=True, |
| 326 | + ) |
| 327 | + |
| 328 | + return parser |
| 329 | + |
| 330 | + |
| 331 | +if __name__ == "__main__": |
| 332 | + args = get_args_parser().parse_args() |
| 333 | + Path(args.output_dir).mkdir(exist_ok=True) |
| 334 | + main(args) |
0 commit comments