RAFT training reference Improvement #5590
Conversation
Change optical flow train.py function name from `validate` to `evaluate` so it is similar to other references
Thank you for the PR @YosuaMichael. There are 2 minor issues (see below), but otherwise this looks great!
Update: support saving the optimizer and scheduler in the checkpoint.
Hi @NicolasHug, I decided to put the commit for saving the optimizer and scheduler in this PR as well: 09d78d1
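For context, here is a minimal, runnable sketch of what bundling the optimizer and scheduler into the checkpoint looks like; the toy model and the key names are illustrative, not necessarily those used by the reference script:

```python
import torch

# Toy stand-ins for the real RAFT model, optimizer, and scheduler.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Save all three state dicts in a single checkpoint file.
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
}
torch.save(checkpoint, "checkpoint.pth")

# On resume, restore all three so training continues where it left off.
checkpoint = torch.load("checkpoint.pth", map_location="cpu")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
```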
Thanks @YosuaMichael, we're almost there :). I made a few comments below.
Thanks @YosuaMichael, nice work! There was a minor issue left, which I fixed in 2857e21: when no train set is specified we want to go directly to `evaluate`, without worrying about `train_dataset`; the previous code would fail because it's None.
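A rough sketch of the kind of early-exit guard this fix implies (flag and function names here are simplified stand-ins for the real ones in `references/optical_flow/train.py`):

```python
import argparse

def evaluate(args):
    # Stand-in for the real evaluation loop.
    print(f"evaluating on {args.val_dataset}")

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train-dataset", default=None)
    parser.add_argument("--val-dataset", default="kitti")
    args = parser.parse_args()

    if args.train_dataset is None:
        # Evaluation-only run: go straight to evaluate() and return,
        # never touching training-only setup that assumes a non-None
        # train_dataset.
        evaluate(args)
        return

    print(f"training on {args.train_dataset}")

if __name__ == "__main__":
    main()
```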
Summary:
* Change optical flow train.py function name from `validate` to `evaluate` so it is similar to other references
* Add `--device` as a parameter and enable running in non-distributed mode
* Format with ufmt
* Fix unnecessary param and bug
* Enable saving the optimizer and scheduler on the checkpoint
* Fix bug when evaluating before resume and when saving or loading the model without DDP
* Fix case where `--train-dataset` is None

(Note: this ignores all push blocking failures!)

Reviewed By: YosuaMichael
Differential Revision: D35216768
fbshipit-source-id: 3b575d9f4a51caed920ff402e160a26ff6f3c2d4
Co-authored-by: Nicolas Hug <[email protected]>
This PR does some of the tasks from #5056, starting with renaming `validate` to `evaluate`.
Sample script to run in non-distributed mode on CPU:
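A minimal sketch of such an invocation (`--train-dataset` and the new `--device` flag appear in this PR; the dataset name and the `--batch-size`/`--epochs` values are illustrative placeholders):

```bash
# Single-process CPU run: plain python, no torchrun.
# Dataset name and hyperparameter values are placeholders.
python train.py \
    --train-dataset kitti \
    --batch-size 2 \
    --epochs 1 \
    --device cpu
```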
To test on CPU, I ran on a mock dataset by replacing https://github.com/pytorch/vision/blob/main/torchvision/datasets/_optical_flow.py with https://gist.github.com/YosuaMichael/9c49729243ff9d467ece06ab8641680d.
Note that, as of now, if we run in distributed mode using torchrun, we must use `--device cuda`.
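For comparison, a distributed run would look roughly like this (the `--nproc_per_node` count is arbitrary; per the note above, only `--device cuda` currently works under torchrun):

```bash
# Distributed run: torchrun spawns one process per GPU.
# This currently requires --device cuda.
torchrun --nproc_per_node=2 train.py \
    --train-dataset kitti \
    --device cuda
```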