[SDXL DreamBooth LoRA] add support for text encoder fine-tuning #4097
Conversation
…nto feat/sdxl-dreambooth-returns
The documentation is not available anymore as the PR was closed or merged.
@patrickvonplaten @williamberman I think I have addressed all your comments.
I would suggest taking another, deeper look.
@@ -809,3 +810,66 @@ def __call__(
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)
# Override to properly handle the loading and unloading of the additional text encoder.
ok for me!
# needed for the SD XL UNet to operate.
def compute_embeddings(prompt, text_encoders, tokenizers):

def compute_time_ids():
    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
    original_size = (args.resolution, args.resolution)
This should ideally be the original size of the passed image (before resizing), but ok to leave as is for now
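To make the discussion above concrete, here is a torch-free sketch of what the `compute_time_ids` step boils down to. The standalone signature and the literal sizes are assumptions for the example (in the script the helper closes over `args`, and, as noted, the true original image size would ideally be passed instead of the training resolution):

```python
def compute_time_ids(original_size, crops_coords_top_left, target_size):
    # SDXL's UNet is additionally conditioned on six integers:
    # (orig_h, orig_w, crop_top, crop_left, target_h, target_w).
    return list(original_size) + list(crops_coords_top_left) + list(target_size)

# Using the training resolution as a stand-in for the true original size:
resolution = 1024
time_ids = compute_time_ids((resolution, resolution), (0, 0), (resolution, resolution))
print(time_ids)  # [1024, 1024, 0, 0, 1024, 1024]
```

In the real script these six values are turned into a tensor and fed to the UNet as `added_cond_kwargs`; the sketch only shows how the tuple is assembled.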
Nice!
Co-authored-by: Patrick von Platen <[email protected]>
@williamberman ok for you?
def save_lora_weights(
    self,
    save_directory: Union[str, os.PathLike],
    unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
    text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
    text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
    is_main_process: bool = True,
    weight_name: str = None,
    save_function: Callable = None,
    safe_serialization: bool = False,
):
    state_dict = {}

    def pack_weights(layers, prefix):
        layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
        layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
        return layers_state_dict

    state_dict.update(pack_weights(unet_lora_layers, "unet"))

    if text_encoder_lora_layers and text_encoder_2_lora_layers:
        state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
        state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
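The key idea in `pack_weights` is simple namespacing: every weight key gets prefixed by its owner so the UNet and both text encoders can share one flat state dict. A dict-only sketch (the `torch.nn.Module` branch is dropped here so the example stays self-contained; the toy key and values are made up):

```python
def pack_weights(layers, prefix):
    # Namespace every weight key under its owner ("unet", "text_encoder",
    # "text_encoder_2") so all three LoRA weight sets coexist in one dict.
    return {f"{prefix}.{name}": param for name, param in layers.items()}

state_dict = {}
state_dict.update(pack_weights({"lora_A.weight": [0.1]}, "unet"))
state_dict.update(pack_weights({"lora_A.weight": [0.2]}, "text_encoder"))
print(sorted(state_dict))  # ['text_encoder.lora_A.weight', 'unet.lora_A.weight']
```

The prefixes are what the loader later keys on to route each weight back to the right model.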
<3
class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images.
    """

    def __init__(
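To show the shape of the dataset contract, here is a torch-free sketch of the DreamBooth pairing logic (class name, fields, and the sample prompt are hypothetical; the real dataset returns preprocessed image tensors, and with SDXL the prompts are encoded once up front rather than tokenized per item):

```python
class DreamBoothDatasetSketch:
    # Every instance image is paired with the same instance prompt;
    # indexing wraps around so the dataset can be oversampled safely.
    def __init__(self, instance_images, instance_prompt):
        self.instance_images = instance_images
        self.instance_prompt = instance_prompt

    def __len__(self):
        return len(self.instance_images)

    def __getitem__(self, index):
        return {
            "instance_images": self.instance_images[index % len(self.instance_images)],
            "instance_prompt": self.instance_prompt,
        }

ds = DreamBoothDatasetSketch(["img0", "img1"], "a photo of sks dog")
print(len(ds), ds[1]["instance_prompt"])  # 2 a photo of sks dog
```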
That's magnificent!
perfect, lgtm!
Thanks all for your suggestions.
…ingface#4097) * Allow low precision sd xl * finish * finish * feat: initial draft for supporting text encoder lora finetuning for SDXL DreamBooth * fix: variable assignments. * add: autocast block. * add debugging * vae dtype hell * fix: vae dtype hell. * fix: vae dtype hell 3. * clean up * lora text encoder loader. * fix: unwrapping models. * add: tests. * docs. * handle unexpected keys. * fix vae dtype in the final inference. * fix scope problem. * fix: save_model_card args. * initialize: prefix to None. * fix: dtype issues. * apply gixes. * debgging. * debugging * debugging * debugging * debugging * debugging * add: fast tests. * pre-tokenize. * address: will's comments. * fix: loader and tests. * fix: dataloader. * simplify dataloader. * length. * simplification. * make style && make quality * simplify state_dict munging * fix: tests. * fix: state_dict packing. * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> --------- Co-authored-by: Patrick von Platen <[email protected]>
This PR adds support for text encoder fine-tuning in the DreamBooth LoRA script for SDXL.
Summary of the changes:
To help us maintain sanity, I tested the current training script under three settings:
Artifacts:
Artifacts:
Artifacts:
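On the loading side, the prefixes written by `save_lora_weights` are what let a loader route each weight back to the right model. A sketch of that counterpart (the helper name is hypothetical and the toy values made up; diffusers' actual loader is more involved):

```python
def split_lora_state_dict(state_dict):
    # Route each packed key back to its owner by the prefix written at
    # save time ("unet", "text_encoder", "text_encoder_2").
    buckets = {"unet": {}, "text_encoder": {}, "text_encoder_2": {}}
    for key, param in state_dict.items():
        prefix, _, rest = key.partition(".")
        if prefix in buckets:
            buckets[prefix][rest] = param
    return buckets

packed = {"unet.lora_A.weight": 1, "text_encoder.lora_A.weight": 2}
out = split_lora_state_dict(packed)
print(out["unet"])  # {'lora_A.weight': 1}
```

This is also why the save path only writes the text-encoder prefixes when text-encoder layers are actually passed in: a UNet-only LoRA round-trips cleanly because its keys never collide with the encoder namespaces.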