Fix various bugs with LoRA Dreambooth and Dreambooth script #3353
```diff
@@ -834,7 +834,6 @@ def main(args):

     unet.set_attn_processor(unet_lora_attn_procs)
     unet_lora_layers = AttnProcsLayers(unet.attn_processors)
-    accelerator.register_for_checkpointing(unet_lora_layers)

     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
```

Review comment on the removed `accelerator.register_for_checkpointing(unet_lora_layers)` line: We now save the LoRA layers in a better format.
```diff
@@ -853,9 +852,68 @@ def main(args):
         )
         temp_pipeline._modify_text_encoder(text_lora_attn_procs)
         text_encoder = temp_pipeline.text_encoder
-        accelerator.register_for_checkpointing(text_encoder_lora_layers)
         del temp_pipeline

+    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+    def save_model_hook(models, weights, output_dir):
+        # There are only two options here: either there are just the unet attn processor layers,
+        # or there are the unet and text encoder attn layers.
+        unet_lora_layers_to_save = None
+        text_encoder_lora_layers_to_save = None
+
+        if args.train_text_encoder:
+            text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
+        unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()
+
+        for model in models:
+            state_dict = model.state_dict()
+
+            if (
+                text_encoder_lora_layers is not None
+                and text_encoder_keys is not None
+                and state_dict.keys() == text_encoder_keys
+            ):
+                # text encoder
+                text_encoder_lora_layers_to_save = state_dict
+            elif state_dict.keys() == unet_keys:
+                # unet
+                unet_lora_layers_to_save = state_dict
+
+            # make sure to pop weight so that corresponding model is not saved again
+            weights.pop()
+
+        LoraLoaderMixin.save_lora_weights(
+            output_dir,
+            unet_lora_layers=unet_lora_layers_to_save,
+            text_encoder_lora_layers=text_encoder_lora_layers_to_save,
+        )
+
+    def load_model_hook(models, input_dir):
+        # Note we DON'T pass the unet and text encoder here on purpose
+        # so that we don't accidentally override the LoRA layers of
+        # unet_lora_layers and text_encoder_lora_layers, which are stored in `models`,
+        # with new torch.nn.Modules / weights. We simply use the pipeline class as
+        # an easy way to load the lora checkpoints.
+        temp_pipeline = DiffusionPipeline.from_pretrained(
+            args.pretrained_model_name_or_path,
+            revision=args.revision,
+            torch_dtype=weight_dtype,
+        )
+        temp_pipeline.load_lora_weights(input_dir)
+
+        # load lora weights into models
+        models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict())
+        if len(models) > 1:
+            models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict())
+
+        # delete temporary pipeline and pop models
+        del temp_pipeline
+        for _ in range(len(models)):
+            models.pop()
+
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)

     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32:
```

Review comment on `def save_model_hook(...)`: Let's make sure the LoRA layers are saved in a more user-friendly way here.

Review comment on `temp_pipeline.load_lora_weights(input_dir)`: Clean and tight 🎸
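For context, here is a minimal, self-contained sketch of the Accelerate hook mechanism the script now relies on. It assumes accelerate >= 0.16; the toy model, file names, and checkpoint path are illustrative, not taken from the PR:

```python
import os

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(4, 4))


def save_model_hook(models, weights, output_dir):
    # `models` holds the prepared models; `weights` holds their state dicts in the same order.
    os.makedirs(output_dir, exist_ok=True)
    for i, m in enumerate(models):
        torch.save(m.state_dict(), os.path.join(output_dir, f"model_{i}.bin"))
        # Popping a weight tells accelerate not to also save this model in its default format.
        weights.pop()


def load_model_hook(models, input_dir):
    for i, m in enumerate(models):
        m.load_state_dict(torch.load(os.path.join(input_dir, f"model_{i}.bin")))
    # Pop the models so accelerate does not try to load them a second time.
    for _ in range(len(models)):
        models.pop()


accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

accelerator.save_state("checkpoint-0")  # routed through save_model_hook
accelerator.load_state("checkpoint-0")  # routed through load_model_hook
```

This is the same save/pop and load/pop pattern the training script uses above, just reduced to a toy model.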
```diff
@@ -1130,17 +1188,10 @@ def compute_text_embeddings(prompt):
                 progress_bar.update(1)
                 global_step += 1

-                if global_step % args.checkpointing_steps == 0:
-                    if accelerator.is_main_process:
+                if accelerator.is_main_process:
+                    if global_step % args.checkpointing_steps == 0:
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                        # We combine the text encoder and UNet LoRA parameters with a simple
-                        # custom logic. `accelerator.save_state()` won't know that. So,
-                        # use `LoraLoaderMixin.save_lora_weights()`.
-                        LoraLoaderMixin.save_lora_weights(
-                            save_directory=save_path,
-                            unet_lora_layers=unet_lora_layers,
-                            text_encoder_lora_layers=text_encoder_lora_layers,
-                        )
+                        accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")

             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
```
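Because `accelerator.save_state` is now routed through the custom save hook, resuming is just the usual Accelerate call. A hedged sketch of the resume path; the checkpoint directory name is illustrative and the step parsing follows the script's usual `checkpoint-{step}` naming convention:

```python
import os

# Illustrative checkpoint directory written by `accelerator.save_state(save_path)` above.
resume_dir = "output/checkpoint-500"

# Triggers `load_model_hook`, which refills the LoRA layers from the saved LoRA weights.
accelerator.load_state(resume_dir)

# Recover the step counter from the directory name, matching the checkpoint naming above.
global_step = int(os.path.basename(resume_dir).split("-")[1])
```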
```diff
@@ -1217,8 +1268,12 @@ def compute_text_embeddings(prompt):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         unet = unet.to(torch.float32)
+        unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)

         if text_encoder is not None:
             text_encoder = text_encoder.to(torch.float32)
+            text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)

         LoraLoaderMixin.save_lora_weights(
             save_directory=args.output_dir,
             unet_lora_layers=unet_lora_layers,
```
```diff
@@ -1250,6 +1305,7 @@ def compute_text_embeddings(prompt):
         pipeline.load_lora_weights(args.output_dir)

         # run inference
+        images = []
         if args.validation_prompt and args.num_validation_images > 0:
             generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
             images = [
```
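Since the final LoRA weights are written with `LoraLoaderMixin.save_lora_weights`, they load straight into a pipeline for the validation pass shown above. A hedged usage sketch, where the base model id, output directory, prompt, seed, and image count stand in for the script's `args.*` values:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("lora-output-dir")  # the directory used as `args.output_dir` during training

generator = torch.Generator(device="cuda").manual_seed(0)
images = [
    pipe("a photo of sks dog in a bucket", num_inference_steps=25, generator=generator).images[0]
    for _ in range(4)
]
```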
The remaining hunks come from a second file in the diff:
```diff
@@ -70,6 +70,9 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]):
         self.mapping = dict(enumerate(state_dict.keys()))
         self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

+        # .processor for unet, .k_proj, .q_proj, .v_proj, and .out_proj for text encoder
+        self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]
+
         # we add a hook to state_dict() and load_state_dict() so that the
         # naming fits with `unet.attn_processors`
         def map_to(module, state_dict, *args, **kwargs):
```
```diff
@@ -81,10 +84,19 @@ def map_to(module, state_dict, *args, **kwargs):

             return new_state_dict

+        def remap_key(key, state_dict):
+            for k in self.split_keys:
+                if k in key:
+                    return key.split(k)[0] + k
+
+            raise ValueError(
+                f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
+            )
+
         def map_from(module, state_dict, *args, **kwargs):
             all_keys = list(state_dict.keys())
             for key in all_keys:
-                replace_key = key.split(".processor")[0] + ".processor"
+                replace_key = remap_key(key, state_dict)
                 new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
                 state_dict[new_key] = state_dict[key]
                 del state_dict[key]
```

Review comment on the `replace_key` change: We need to be able to split this according to the other keys as well.
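To illustrate what `remap_key` does, here is a tiny standalone sketch. The two example keys are made up to show the shape of unet-style and text-encoder-style keys; they are not copied from a real checkpoint:

```python
split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]


def remap_key(key):
    # Cut the key right after the first split token it contains; that prefix is
    # what `rev_mapping` indexes on.
    for k in split_keys:
        if k in key:
            return key.split(k)[0] + k
    raise ValueError(f"{key} has to have one of {split_keys}.")


# unet-style key -> everything up to and including ".processor"
print(remap_key("down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.up.weight"))
# text-encoder-style key -> everything up to and including ".q_proj"
print(remap_key("text_model.encoder.layers.0.self_attn.q_proj.lora_linear_layer.up.weight"))
```

Before this change only the `.processor` split existed, so text-encoder keys, which end in `.q_proj`, `.k_proj`, `.v_proj`, or `.out_proj` instead, could not be remapped.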
```diff
@@ -898,6 +910,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
             self._modify_text_encoder(attn_procs_text_encoder)

+            # save lora attn procs of text encoder so that it can be easily retrieved
+            self._text_encoder_lora_attn_procs = attn_procs_text_encoder
+
         # Otherwise, we're dealing with the old format. This means the `state_dict` should only
         # contain the module names of the `unet` as its keys WITHOUT any prefix.
         elif not all(
```
```diff
@@ -907,6 +922,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
             warnings.warn(warn_message)

+    @property
+    def text_encoder_lora_attn_procs(self):
+        if hasattr(self, "_text_encoder_lora_attn_procs"):
+            return self._text_encoder_lora_attn_procs
+        return
+
     def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
```
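A hedged usage sketch of the new property (the model id and checkpoint path are placeholders): after loading LoRA weights in the new format, the text-encoder attention processors can be retrieved again, for example to re-wrap them in `AttnProcsLayers` the way the training script's `load_model_hook` does.

```python
from diffusers import DiffusionPipeline
from diffusers.loaders import AttnProcsLayers

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("path/to/lora-checkpoint")

# Returns None if no text-encoder LoRA layers were loaded.
text_procs = pipe.text_encoder_lora_attn_procs
if text_procs is not None:
    text_encoder_lora_layers = AttnProcsLayers(text_procs)
```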
```diff
@@ -1110,7 +1131,7 @@ def _load_text_encoder_attn_procs(
     def save_lora_weights(
         self,
         save_directory: Union[str, os.PathLike],
-        unet_lora_layers: Dict[str, torch.nn.Module] = None,
+        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
         is_main_process: bool = True,
         weight_name: str = None,
```

Review comment on the `unet_lora_layers` annotation: Let's also allow passing the LoRA layers as low-level torch.Tensor state dicts.
```diff
@@ -1123,13 +1144,14 @@ def save_lora_weights(
         Arguments:
             save_directory (`str` or `os.PathLike`):
                 Directory to which to save. Will be created if it doesn't exist.
-            unet_lora_layers (`Dict[str, torch.nn.Module`]):
+            unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the UNet. Specifying this helps to make the
-                serialization process easier and cleaner.
-            text_encoder_lora_layers (`Dict[str, torch.nn.Module`]):
+                serialization process easier and cleaner. Values can be either LoRA torch.nn.Module layers or torch
+                weights.
+            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the `text_encoder`. Since the `text_encoder` comes from
                 `transformers`, we cannot rejig it. That is why we have to explicitly pass the text encoder LoRA state
-                dict.
+                dict. Values can be either LoRA torch.nn.Module layers or torch weights.
             is_main_process (`bool`, *optional*, defaults to `True`):
                 Whether the process calling this is the main process or not. Useful when in distributed training like
                 TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
```
```diff
@@ -1157,15 +1179,22 @@ def save_function(weights, filename):
         # Create a flat dictionary.
         state_dict = {}
         if unet_lora_layers is not None:
-            unet_lora_state_dict = {
-                f"{self.unet_name}.{module_name}": param
-                for module_name, param in unet_lora_layers.state_dict().items()
-            }
+            weights = (
+                unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
+            )
+
+            unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
             state_dict.update(unet_lora_state_dict)

         if text_encoder_lora_layers is not None:
+            weights = (
+                text_encoder_lora_layers.state_dict()
+                if isinstance(text_encoder_lora_layers, torch.nn.Module)
+                else text_encoder_lora_layers
+            )
+
             text_encoder_lora_state_dict = {
-                f"{self.text_encoder_name}.{module_name}": param
-                for module_name, param in text_encoder_lora_layers.state_dict().items()
+                f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
             }
             state_dict.update(text_encoder_lora_state_dict)
```
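The practical effect, mirroring how the training script above calls this method: the final save passes the `AttnProcsLayers` wrapper modules directly, while `save_model_hook` passes plain state dicts at checkpoint time, and both now go through the same path. A hedged sketch of the two call styles; the directory name is illustrative, and `unet_lora_layers` / `text_encoder_lora_layers` are assumed to be the trained `AttnProcsLayers` wrappers from the script:

```python
# 1) Pass the wrapper modules themselves (what the script does for the final save).
LoraLoaderMixin.save_lora_weights(
    save_directory="lora-out",
    unet_lora_layers=unet_lora_layers,
    text_encoder_lora_layers=text_encoder_lora_layers,
)

# 2) Pass already-materialized state dicts (what `save_model_hook` does at checkpoint time).
LoraLoaderMixin.save_lora_weights(
    save_directory="lora-out",
    unet_lora_layers=unet_lora_layers.state_dict(),
    text_encoder_lora_layers=text_encoder_lora_layers.state_dict(),
)
```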
Review comment: The minimum version of accelerate is already 0.16, so let's remove this check here.
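For context, the kind of guard being referred to is an accelerate version check around features that only exist from 0.16 on, such as the save/load state hooks used above. A hedged sketch of what such a check typically looks like; this is illustrative, reuses the names from the training-script hunk above, and is not the PR's exact code:

```python
import accelerate
from packaging import version

# With accelerate >= 0.16 pinned as the minimum, this guard becomes dead code
# and the hooks can be registered unconditionally.
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)
else:
    raise ImportError("accelerate>=0.16.0 is required for custom save/load state hooks.")
```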