[Feature] Finetune text encoder in train_text_to_image_lora #3912
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
cc @sayakpaul here
Thanks for adding this pipeline. Could you also share some comparative results with and without text encoder LoRA fine-tuning? Just trying to gauge the effectiveness.
@sayakpaul Here are comparison results: only UNet vs. with text encoder.
Thanks for your hard work! Maybe let's wait for #3778 to get merged, as there are some refactoring-related changes that will make this simpler. Okay for you?
@sayakpaul No problem.
Gentle ping @sayakpaul - do you think we could merge this?
@patrickvonplaten Should we update the code based on this PR?
@sayakpaul I updated the code, but it caused the following errors:
Hi, could you also try out the solutions provided in that thread to see if the errors persist? Also, what happens if we do this in Torch 2.0, taking advantage of SDPA and disabling xformers?
When disabling xformers:
I tried some patterns, but all of them failed.
@sayakpaul PyTorch 2.0 and
Thanks for letting me know and for your efforts. Could we maybe try to dive a bit deeper to see what the issue is? For example, does the issue persist when not training the text encoder?
It also persists when training only the UNet.
Okay. Can we verify if the version in the main branch works with the following settings?
Just trying to narrow down the issue. If you have other ideas to try out, please let me know.
PyTorch 2.0 with SDPA works well.
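For reference, a minimal sketch of the SDPA path that worked here, assuming PyTorch >= 2.0; the tiny test checkpoint id is an illustrative assumption, not part of the original thread.

```python
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0

# Tiny test UNet as a stand-in for the full Stable Diffusion UNet
# (the repo id here is an assumption, used only for illustration).
unet = UNet2DConditionModel.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet"
)

# With PyTorch >= 2.0, diffusers dispatches to
# torch.nn.functional.scaled_dot_product_attention by default; setting the
# processor explicitly just makes that choice visible.
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
    unet.set_attn_processor(AttnProcessor2_0())
```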
Thanks for reporting. This is indeed weird.
The error was raised when I used my branch of the script; the script on the main branch works fine.
Oh okay. Let's try to investigate the differences then :-) |
The error was solved by changing where xformers is applied.
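For context, a sketch of one plausible reading of that fix, under the assumption that ordering was the culprit: `set_attn_processor` replaces whatever processors are installed, so enabling xformers before attaching the LoRA processors would be undone. The tiny checkpoint id is an assumption, running this needs a CUDA machine with xformers installed, and this is not the PR's literal diff.

```python
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor

# Tiny test UNet standing in for the full Stable Diffusion UNet.
unet = UNet2DConditionModel.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet"
)

# 1. Attach the LoRA attention processors first (the standard loop from the
#    training script, sizing each processor for its block).
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    cross_attention_dim = (
        None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    )
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRAAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )
unet.set_attn_processor(lora_attn_procs)

# 2. Only then enable xformers, so the LoRA processors are swapped for their
#    xformers-aware variants instead of being overwritten.
unet.enable_xformers_memory_efficient_attention()
```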
@sayakpaul This PR is ready. |
```python
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
    logger.warn(
        "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
    )
unet.enable_xformers_memory_efficient_attention()
```
❤️
Excellent work here. Thanks so much for all your experiments and iterations!
Let's also update the README of the example so that readers are aware that text encoder training is supported. We also need to add a test case for text encoder training to `test_examples.py`. Then this PR is good to be shipped 🚀
Also, could we do a full run with the changes to ensure the results you're getting are as expected?
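As an illustration of such a test case, here is a rough sketch assuming the harness in `examples/test_examples.py`, where tests live in a class that exposes `run_command` and `self._launch_args`; every model id, dataset id, flag value, and filename below is an assumption, not the test that was actually added.

```python
# Hypothetical test method for examples/test_examples.py; it relies on the
# file's existing imports (os, tempfile) and its run_command helper.
def test_text_to_image_lora_with_text_encoder(self):
    with tempfile.TemporaryDirectory() as tmpdir:
        test_args = f"""
            examples/text_to_image/train_text_to_image_lora.py
            --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
            --dataset_name hf-internal-testing/dummy_image_text_data
            --resolution 64
            --train_batch_size 1
            --gradient_accumulation_steps 1
            --max_train_steps 2
            --learning_rate 5.0e-04
            --train_text_encoder
            --output_dir {tmpdir}
            """.split()

        run_command(self._launch_args + test_args)
        # Training should write out the combined UNet + text encoder LoRA weights.
        self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
```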
You can download images from this page: https://civitai.com/models/25613/classic-anime-expressions. The params are as follows:
docs/source/en/training/lora.mdx (Outdated)
```python
from huggingface_hub.repocard import RepoCard
from diffusers import StableDiffusionPipeline
import torch

lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.load_lora_weights(lora_model_id)
image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
```
Let's use a LoRA you trained with this script? We'd also need to update the example prompt :)
Looks amazing. Thanks for iterating.
Additionally, I'd also include a note about the `--train_text_encoder` flag in the README here.
@williamberman could you also take a look here?
Co-authored-by: Sayak Paul <[email protected]>
```python
if args.enable_xformers_memory_efficient_attention:
    if is_xformers_available():
        import xformers

        xformers_version = version.parse(xformers.__version__)
        if xformers_version == version.parse("0.0.16"):
            logger.warn(
                "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
            )
        unet.enable_xformers_memory_efficient_attention()
    else:
        raise ValueError("xformers is not available. Make sure it is installed correctly")
```
Unless I'm missing something, in the future it'd be helpful not to move a code block unrelated to the PR, as it makes the diff harder to read :)
#3912 (comment)
#3912 (comment)
By moving the code block, I avoid this error.
```python
with torch.cuda.amp.autocast():
    images.append(
        pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
    )
```
did something change that required autocast to be added?
It's also related to this error; you can check this thread.
Looks basically good, a few small questions :)
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
Fixes #3418
Adds a `--train_text_encoder` argument to `train_text_to_image_lora.py` so that the text encoder can be fine-tuned along with the UNet.
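A minimal, runnable sketch of what the flag changes, with dummy `torch.nn.Linear` modules standing in for the real LoRA layers; the actual script wires the genuine UNet and text encoder LoRA parameters into the optimizer, and also handles saving and loading them.

```python
import argparse
import itertools

import torch

parser = argparse.ArgumentParser()
parser.add_argument(
    "--train_text_encoder",
    action="store_true",
    help="Also fine-tune LoRA layers for the text encoder.",
)
args = parser.parse_args(["--train_text_encoder"])  # example invocation

unet_lora = torch.nn.Linear(4, 4)          # stand-in for the UNet's LoRA layers
text_encoder_lora = torch.nn.Linear(4, 4)  # stand-in for the text encoder's LoRA layers

# The flag gates which parameter groups reach the optimizer.
if args.train_text_encoder:
    params_to_optimize = itertools.chain(
        unet_lora.parameters(), text_encoder_lora.parameters()
    )
else:
    params_to_optimize = unet_lora.parameters()

optimizer = torch.optim.AdamW(params_to_optimize, lr=1e-4)
print(sum(p.numel() for group in optimizer.param_groups for p in group["params"]))  # 40
```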
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.