
Add training reference for optical flow models #5027


Merged: 9 commits into pytorch:main on Dec 7, 2021

Conversation

NicolasHug (Member) commented Dec 3, 2021

Towards #4644

This PR adds a training reference for optical flow models (mostly just RAFT at the moment), as well as utilities for evaluating the model on Sintel (epe, 1px, 3px, 5px) or Kitti (per-image-epe, F1).
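For reference, epe is the mean end-point error (the Euclidean distance between predicted and ground-truth flow vectors), and the Npx numbers are the fraction of pixels with an EPE below N. A minimal sketch of the computation (the helper name and signature are illustrative, not the exact code in this PR):

```python
import torch

def sintel_metrics(flow_pred, flow_gt, valid_mask=None):
    # flow_pred, flow_gt: (N, 2, H, W) flow fields; valid_mask: optional (N, H, W) bool
    # End-point error: Euclidean distance between predicted and true flow vectors
    epe = torch.sum((flow_pred - flow_gt) ** 2, dim=1).sqrt()
    if valid_mask is not None:
        epe = epe[valid_mask]
    return {
        "epe": epe.mean().item(),
        "1px": (epe < 1).float().mean().item(),  # fraction of pixels with EPE below 1
        "3px": (epe < 3).float().mean().item(),
        "5px": (epe < 5).float().mean().item(),
    }
```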

Right now, the training script assumes that both CUDA and DDP are available. It must be run with torchrun, e.g.

torchrun --nproc_per_node 8 --nnodes 1 references/optical_flow/train.py --batch-size 10 --train-dataset chairs --val-dataset kitti sintel

Our custom run_with_submitit.py script is also partially supported (but it's not as useful anyway, because the training procedure involves training on more than one dataset).

I left a few TODO comments as potential future improvements.

CC @fmassa @datumbox @haooooooqi


@NicolasHug added the module: reference scripts and other labels on Dec 3, 2021
facebook-github-bot commented Dec 3, 2021

💊 CI failures summary and remediations

As of commit 2ef1af5 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

@NicolasHug mentioned this pull request on Dec 3, 2021


def sequence_loss(flow_preds, flow_gt, valid_flow_mask, gamma=0.8, max_flow=400):
"""Loss function defined over sequence of flow predictions"""
NicolasHug (Member Author) commented:

This loss function is very RAFT-specific, because it assumes the model outputs a series of predictions, instead of a single predicted flow.
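Concretely, RAFT's sequence loss is an exponentially weighted sum of per-iteration L1 losses, with the last (most refined) prediction weighted highest. A rough sketch of the idea, not the exact code in this diff:

```python
import torch

def sequence_loss_sketch(flow_preds, flow_gt, valid_flow_mask, gamma=0.8, max_flow=400):
    # flow_preds: list of (N, 2, H, W) predictions, one per refinement iteration.
    # Exclude pixels whose ground-truth flow magnitude is unreasonably large.
    flow_norm = torch.sum(flow_gt ** 2, dim=1).sqrt()
    valid = valid_flow_mask & (flow_norm < max_flow)

    n = len(flow_preds)
    loss = 0.0
    for i, pred in enumerate(flow_preds):
        weight = gamma ** (n - i - 1)  # final prediction weighted most heavily
        abs_diff = (pred - flow_gt).abs().sum(dim=1)  # L1 over the 2 flow channels
        loss = loss + weight * abs_diff[valid].mean()
    return loss
```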

import torch.nn.functional as F


class SmoothedValue:
NicolasHug (Member Author) commented:

This and the MetricLogger below have been copy/pasted from the classification references. I only made some very minor changes, like setting some defaults, so it's probably not worth reviewing.
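For context, the helper boils down to something like this simplified sketch (window-based running statistics; the real version in the references also synchronizes across processes):

```python
from collections import deque

class SmoothedValueSketch:
    """Tracks a windowed average over recent updates plus a global average."""

    def __init__(self, window_size=20):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        self.deque.append(value)
        self.total += value * n
        self.count += n

    @property
    def avg(self):
        # Average over the last `window_size` updates only
        return sum(self.deque) / len(self.deque)

    @property
    def global_avg(self):
        return self.total / self.count
```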

fmassa (Member) left a comment:

Thanks for the PR!

I've made a few comments, all of which can be addressed in follow-up PRs.

In particular, I think we can split the loss up a bit more, so that parts of it can be re-used across train / val.

Comment on lines +252 to +253
torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth")
fmassa (Member) commented:

Should we also save the optimizer and the scheduler so that we can resume training? This is what we do in the other reference scripts
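For reference, the resume pattern in the other reference scripts looks roughly like this (a sketch; model, optimizer, scheduler, args, and epoch are assumed to come from the surrounding training loop):

```python
from pathlib import Path

import torch

# Saving: bundle everything needed to resume, not just the weights.
# (model, optimizer, scheduler, args, epoch assumed from the training loop.)
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
    "epoch": epoch,
    "args": args,
}
torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{epoch}.pth")

# Resuming later:
checkpoint = torch.load(args.resume, map_location="cpu")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
start_epoch = checkpoint["epoch"] + 1
```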

A contributor commented:

We should definitely do this. I think it's worth refactoring the script to have the same functionality and structure as the other reference scripts. Moreover, we will need to link the reference scripts with the model prototype and add the --weights feature switch.

@NicolasHug do you mind creating an issue for all of the above so that we don't forget?

NicolasHug (Member Author) commented:

Opened #5056

Comment on lines +32 to +34
# As future improvement, we could probably be using a distributed sampler here
# The distribution is S(.71), T(.135), K(.135), H(.02)
return 100 * sintel + 200 * kitti + 5 * hd1k + things_clean
fmassa (Member) commented:

Ok with me. So you added support for __mul__ in those datasets?
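For anyone wondering how 100 * sintel + 200 * kitti can work: one way is to define __rmul__ on a dataset base class as repeated concatenation, since torch.utils.data.Dataset already provides __add__ via ConcatDataset. A sketch under that assumption (FlowDatasetSketch is hypothetical, not necessarily what this PR does):

```python
from torch.utils.data import ConcatDataset, Dataset

class FlowDatasetSketch(Dataset):
    """Hypothetical base class showing how `5 * dataset` could be supported."""

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

    def __rmul__(self, v):
        # `v * dataset` concatenates `v` copies of the dataset
        return ConcatDataset([self] * v)

if __name__ == "__main__":
    a = FlowDatasetSketch(list(range(3)))
    b = FlowDatasetSketch(list(range(5)))
    mixed = 2 * a + b  # Dataset.__add__ gives a ConcatDataset
    print(len(mixed))  # 11
```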

@NicolasHug merged commit 4dd8b5c into pytorch:main on Dec 7, 2021
facebook-github-bot pushed a commit that referenced this pull request Dec 9, 2021
Reviewed By: NicolasHug

Differential Revision: D32950938

fbshipit-source-id: 0f271d45026c821c109493d9aa7f404b5373012d
Labels: ciflow/default, cla signed, module: reference scripts, other