[Examples] Add support for Min-SNR weighting strategy for better convergence #2899
Conversation
@patil-suraj can you take a look here?
Actually, I would suggest not to. I am still gathering evidence to see if the PR is worth merging or even reviewing. After I am done, I will update here. Till then, don't worry about it.
```python
return fn


def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
```
is this related to this PR title or does it just add logging?
The PR could have been done without this utility, but it was important to include it; otherwise, it was difficult to validate the effectiveness of the method.
FWIW, I am not a fan of adding unrelated changes in a PR, but this one seemed important.
Ok for me! It would be great to stick to the more imperative coding style we prefer in the library. I don't think it's very easy to follow when a function returns another function, etc.; that feels a bit too JAX/TF-like to me.
@patrickvonplaten the latest changes should have addressed all your concerns. PTAL.
Thanks a lot for adding support for this! The loss curve does look smoother than without min-SNR weighting. Just left a small comment about v-prediction.
It would also be cool to add this to train_unconditional.py; it would be easy to verify there since we train from scratch.
```python
mse_loss_weights = (
    torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
```
Nice!
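For context, the `snr` tensor consumed above is derived from the scheduler's noise schedule. A minimal sketch of such a helper, assuming a DDPM-style scheduler that exposes `alphas_cumprod` (the name `compute_snr` and the exact placement are illustrative, not necessarily the merged code):

```python
import torch


def compute_snr(noise_scheduler, timesteps):
    """Return SNR(t) = (alpha_t / sigma_t) ** 2 for each sampled timestep."""
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = alphas_cumprod**0.5
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # Gather the per-timestep coefficients for this batch of sampled timesteps.
    alpha = sqrt_alphas_cumprod.to(timesteps.device)[timesteps].float()
    sigma = sqrt_one_minus_alphas_cumprod.to(timesteps.device)[timesteps].float()

    return (alpha / sigma) ** 2
```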
```python
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
```
Some models (SD 2.1 and above) use v-prediction. Does this formulation also work with v-prediction?
Will run an experiment and add a PR for that. Thanks for approving. Now, I will:
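For reference, a plausible adaptation for v-prediction (a hedged sketch, not code from this PR): since the v-target already carries an SNR-dependent scale, the denominator shifts from `snr` to `snr + 1`.

```python
# Hedged sketch, not part of this PR: adapting the Min-SNR weight to v-prediction.
snr = compute_snr(noise_scheduler, timesteps)
base_weight = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0]

if noise_scheduler.config.prediction_type == "v_prediction":
    # The v-target includes an extra SNR-dependent scale, hence the +1.
    mse_loss_weights = base_weight / (snr + 1)
else:  # "epsilon"
    mse_loss_weights = base_weight / snr
```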
@patil-suraj when you get a moment, could you review the changes introduced in 96e7254? All of them are related to documentation. I think then we can merge. @patrickvonplaten maybe you also want to take a look.
```python
}


def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
```
Added this method as part of the PR as well. It handles EMA offload and unload properly to ensure inference is done with the EMA'd checkpoints.
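For context, the swap follows the usual store/copy/restore pattern; a sketch assuming `ema_unet` is a `diffusers.training_utils.EMAModel` and `args.use_ema` gates it (names are assumptions, not the exact script code):

```python
if args.use_ema:
    ema_unet.store(unet.parameters())    # stash the raw training weights
    ema_unet.copy_to(unet.parameters())  # run validation with the EMA weights

# ... build the pipeline, generate validation images, log them to the tracker ...

if args.use_ema:
    ema_unet.restore(unet.parameters())  # put the training weights back
```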
Thanks a lot for adding the doc. Looks good!
We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556), which helps to achieve faster convergence by rebalancing the loss. In order to use it, one needs to set the `--snr_gamma` argument. The recommended value when using it is 5.0.
Out of curiosity, where is this value proposed? Is there a rule of thumb when choosing a value for this?
It's reported in the paper: a gamma of 5.0 consistently led to better results in the authors' experiments.
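For illustration, the flag could be exposed in the script's argument parser along these lines (a sketch; the default and help text are assumptions, not necessarily the exact code):

```python
import argparse

parser = argparse.ArgumentParser(description="Text-to-image fine-tuning script.")
parser.add_argument(
    "--snr_gamma",
    type=float,
    default=None,  # Min-SNR weighting stays off unless the flag is passed.
    help=(
        "SNR weighting gamma to use when rebalancing the loss with Min-SNR. "
        "Recommended value is 5.0. See https://arxiv.org/abs/2303.09556."
    ),
)
args = parser.parse_args()
```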
<Tip warning={true}>

Training with Min-SNR weighting strategy is only supported in PyTorch.

</Tip>
For a future PR: could be cool to add this in JAX as well; it will be useful for the JAX event.
@yiyixuxu could you take a look?
Co-authored-by: Suraj Patil <[email protected]>
Very cool! Thanks for iterating
…ergence (huggingface#2899)

* improve stable unclip doc.
* feat: support for applying min-snr weighting for faster convergence.
* add: support for validation logging with wandb
* make not a required arg.
* fix: arg name.
* fix: cli args.
* fix: tracker config.
* fix: loss calculation.
* fix: validation logging.
* fix: unwrap call.
* fix: validation logging.
* fix: internval.
* fix: checkpointing push to hub.
* fix: https://github.com/huggingface/diffusers/commit/c8a2856c6d5e45577bf4c24dee06b1a4a2f5c050#commitcomment-106913193
* fix: norm group test for UNet3D.
* address PR comments.
* remove unneeded code.
* add: entry in the readme and docs.
* Apply suggestions from code review

Co-authored-by: Suraj Patil <[email protected]>
[1] introduces the Min-SNR weighting strategy, which rebalances the loss when training diffusion models to achieve faster convergence. The authors attribute the slow convergence of diffusion models to conflicting optimization dynamics across the timesteps of the noise schedule, so they introduce a simple way to rebalance the losses of the individual samples.
This PR refers to [1] and [2] to incorporate the Min-SNR weighting strategy into the text-to-image fine-tuning script. I believe this rebalancing can be incorporated into other examples too, where we fine-tune the diffusion model (i.e., the UNet).
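Concretely, for ε-prediction, each sample's MSE loss is scaled by the clipped inverse SNR of its timestep (per Section 3.4 of [1]; the notation below is mine):

$$w_t = \frac{\min\left(\mathrm{SNR}(t),\; \gamma\right)}{\mathrm{SNR}(t)}$$

With the recommended $\gamma = 5$, low-noise (high-SNR) timesteps are down-weighted, while all other timesteps keep a weight of 1.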
My experimentation results are available here: https://wandb.ai/sayakpaul/text2image-finetune-minsnr
Overall, this strategy helps keep the training loss curve less bumpy: https://wandb.ai/sayakpaul/text2image-finetune-minsnr/reports/train_loss-23-04-04-08-49-34---VmlldzozOTY3ODQ2
The number of training samples is definitely a factor that can make the effect of this strategy less pronounced. But overall, I think it's nice to have, as it's directly related to overcoming the training instabilities of diffusion models.
@patil-suraj if we're okay with this change, I will add a section on it in the README and also in https://huggingface.co/docs/diffusers/training/text2image. Let me know.
References
[1] Paper reference: https://arxiv.org/abs/2303.09556
[2] Code: https://github.com/TiankaiHang/Min-SNR-Diffusion-Training