Add training reference for optical flow models #5027
Conversation
💊 CI failures summary and remediations: as of commit 2ef1af5, 💚 looks good so far! There are no failures yet. (This comment was automatically generated by Dr. CI.)
references/optical_flow/utils.py (outdated):

```python
def sequence_loss(flow_preds, flow_gt, valid_flow_mask, gamma=0.8, max_flow=400):
    """Loss function defined over sequence of flow predictions"""
```
This loss function is very RAFT-specific, because it assumes the model outputs a series of predictions, instead of a single predicted flow.
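For context, a minimal sketch of what such a RAFT-style sequence loss typically looks like (an illustration with assumed shapes and weighting, not the PR's exact code):

```python
import torch


def sequence_loss(flow_preds, flow_gt, valid_flow_mask, gamma=0.8, max_flow=400):
    """Sketch of a RAFT-style loss: every intermediate prediction is penalized,
    with exponentially larger weights for later refinement steps."""
    # Exclude invalid pixels and pixels whose ground-truth flow is extreme.
    flow_norm = torch.sum(flow_gt ** 2, dim=1).sqrt()  # (N, H, W)
    valid = valid_flow_mask & (flow_norm < max_flow)   # (N, H, W), bool

    loss = 0.0
    num_preds = len(flow_preds)
    for i, pred in enumerate(flow_preds):
        weight = gamma ** (num_preds - i - 1)          # later predictions count more
        l1 = (pred - flow_gt).abs().sum(dim=1)         # (N, H, W)
        loss = loss + weight * (l1 * valid.float()).mean()
    return loss
```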
references/optical_flow/utils.py:

```python
import torch.nn.functional as F


class SmoothedValue:
```
This, and the MetricLogger below, have been copy/pasted from the classification references. I only made some very minor changes, like setting some defaults; probably not worth reviewing.
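For reference, the utility works roughly like this (a simplified sketch of the classification-reference version, not a verbatim copy):

```python
from collections import deque

import torch


class SmoothedValue:
    """Tracks a series of values and exposes smoothed statistics over a
    sliding window, plus a global average over the whole series."""

    def __init__(self, window_size=20, fmt="{median:.4f} ({global_avg:.4f})"):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    @property
    def median(self):
        return torch.tensor(list(self.deque)).median().item()

    @property
    def global_avg(self):
        return self.total / self.count

    def __str__(self):
        return self.fmt.format(median=self.median, global_avg=self.global_avg)
```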
Thanks for the PR!

I've made a few comments, all of which can be addressed in follow-up PRs. In particular, I think we can split the loss a bit more so that parts of it can be reused across train / val (a sketch follows below).
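A possible shape for that split (purely illustrative; `flow_epe` is a made-up helper name):

```python
def flow_epe(pred, flow_gt, valid_flow_mask):
    """Per-pixel endpoint error over valid pixels, reusable by both the
    training loss and the validation metrics."""
    epe = ((pred - flow_gt) ** 2).sum(dim=1).sqrt()  # (N, H, W)
    return epe[valid_flow_mask]
```

Training could then apply the sequence weighting on top of this helper, while validation could derive epe/1px/3px/5px from the same values.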
```python
torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
torch.save(model.state_dict(), Path(args.output_dir) / f"{args.name}.pth")
```
Should we also save the optimizer and the scheduler, so that we can resume training? This is what we do in the other reference scripts.
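For illustration, the pattern in the other reference scripts looks roughly like this (`save_checkpoint` is a hypothetical wrapper; the dict keys are assumptions):

```python
from pathlib import Path

import torch


def save_checkpoint(model, optimizer, scheduler, current_epoch, args):
    # Bundle everything needed to resume training, not just the model weights.
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "epoch": current_epoch,
    }
    torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{current_epoch}.pth")
    torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")
```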
We should definitely do this. I think it's worth refactoring the script to have the same functionality and structure as the other reference scripts. Moreover, we will need to link the ref scripts with the model prototype and add the `--weights` feature switch.

@NicolasHug do you mind creating an issue for all of the above so that we don't forget?
Opened #5056
```python
# As future improvement, we could probably be using a distributed sampler here
# The distribution is S(.71), T(.135), K(.135), H(.02)
return 100 * sintel + 200 * kitti + 5 * hd1k + things_clean
```
Ok with me. So you added support for `__mul__` in those datasets?
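Presumably it is implemented along these lines (a sketch under the assumption that repetition is backed by `ConcatDataset`; the actual code may differ):

```python
from torch.utils.data import ConcatDataset, Dataset


class RepeatableDataset(Dataset):
    """Toy dataset showing how `100 * dataset` could be supported."""

    def __init__(self, items):
        self.items = items

    def __getitem__(self, idx):
        return self.items[idx]

    def __len__(self):
        return len(self.items)

    def __mul__(self, v):
        # Repeat the dataset v times via concatenation.
        return ConcatDataset([self] * v)

    __rmul__ = __mul__  # supports `v * dataset`, as in `100 * sintel`


# len(3 * RepeatableDataset([0, 1])) == 6
```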
Reviewed By: NicolasHug
Differential Revision: D32950938
fbshipit-source-id: 0f271d45026c821c109493d9aa7f404b5373012d
Towards #4644
This PR adds a training reference for optical flow models (mostly just RAFT at the moment), as well as utilities for evaluating the model on Sintel (epe, 1px, 3px, 5px) or Kitti (per-image epe, F1).
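For context, Sintel-style metrics are typically computed along these lines (a sketch, not the PR's exact code; `sintel_metrics` is a made-up name):

```python
import torch


def sintel_metrics(pred, flow_gt):
    """Mean endpoint error plus the fraction of pixels within 1, 3 and 5 px."""
    epe = ((pred - flow_gt) ** 2).sum(dim=1).sqrt().flatten()
    return {
        "epe": epe.mean().item(),
        "1px": (epe < 1).float().mean().item(),
        "3px": (epe < 3).float().mean().item(),
        "5px": (epe < 5).float().mean().item(),
    }
```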
Right now, the training script assumes that CUDA is available and that DDP is available as well. It must be run with `torchrun`; an example invocation is sketched below.
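For example (a hypothetical invocation; only `torchrun`'s own flags are certain, and the script flags are assumptions based on the snippets above):

```sh
torchrun --nproc_per_node=8 train.py --name raft --output-dir checkpoints
```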
Our custom `run_with_submitit.py` script is also partially supported (but it's not as useful anyway, because the training procedure involves training on more than one dataset).

I marked a few TODO comments, leaving them as future potential improvements.
CC @fmassa @datumbox @haooooooqi
cc @datumbox