diff --git a/docs/source/en/training/lora.mdx b/docs/source/en/training/lora.mdx
index 748d99d5020d..484b08ce950a 100644
--- a/docs/source/en/training/lora.mdx
+++ b/docs/source/en/training/lora.mdx
@@ -272,4 +272,75 @@ Note that the use of [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] is
 * LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth).
 
 **Note** that it is possible to provide a local directory path to [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] as well as [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`]. To know about the supported inputs,
-refer to the respective docstrings.
\ No newline at end of file
+refer to the respective docstrings.
+
+## Supporting A1111-style LoRA checkpoints in Diffusers
+
+To provide seamless interoperability with A1111, Diffusers supports loading A1111-formatted
+LoRA checkpoints through [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] in a limited capacity.
+In this section, we explain how to load an A1111-formatted LoRA checkpoint from [CivitAI](https://civitai.com/)
+into Diffusers and perform inference with it.
+
+First, download a checkpoint. We'll use
+[this one](https://civitai.com/models/13239/light-and-shadow) for demonstration purposes.
+
+```bash
+wget https://civitai.com/api/download/models/15603 -O light_and_shadow.safetensors
+```
+
+Next, we initialize a [`~DiffusionPipeline`]:
+
+```python
+import torch
+
+from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
+
+pipeline = StableDiffusionPipeline.from_pretrained(
+    "gsdf/Counterfeit-V2.5", torch_dtype=torch.float16, safety_checker=None
+).to("cuda")
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+    pipeline.scheduler.config, use_karras_sigmas=True
+)
+```
+
+We then load the checkpoint downloaded from CivitAI:
+
+```python
+pipeline.load_lora_weights(".", weight_name="light_and_shadow.safetensors")
+```
+
+<Tip>
+
+If you're loading a checkpoint in the `safetensors` format, please ensure you have `safetensors` installed.
+
+</Tip>
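+
+If it isn't installed yet, you can get it from PyPI:
+
+```bash
+pip install safetensors
+```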
+
+Then it's time to run inference:
+
+```python
+prompt = "masterpiece, best quality, 1girl, at dusk"
+negative_prompt = ("(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), "
+                   "bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2), large breasts")
+
+images = pipeline(prompt=prompt,
+    negative_prompt=negative_prompt,
+    width=512,
+    height=768,
+    num_inference_steps=15,
+    num_images_per_prompt=4,
+    generator=torch.manual_seed(0)
+).images
+```
+
+Below is a comparison between the LoRA and the non-LoRA results:
+
+![lora_non_lora](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lora_non_lora_comparison.png)
+
+If you have a similar checkpoint stored on the Hugging Face Hub, you can load it
+directly with [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] like so:
+
+```python
+lora_model_id = "sayakpaul/civitai-light-shadow-lora"
+lora_filename = "light_and_shadow.safetensors"
+pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
+```
\ No newline at end of file
diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 4ff759dcd6d4..4bd7b2ec3cb2 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -58,7 +58,7 @@
     SlicedAttnAddedKVProcessor,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available
+from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
 
@@ -847,9 +847,9 @@ def main(args):
     if args.train_text_encoder:
         text_lora_attn_procs = {}
         for name, module in text_encoder.named_modules():
-            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
+            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
                 text_lora_attn_procs[name] = LoRAAttnProcessor(
-                    hidden_size=module.out_features, cross_attention_dim=None
+                    hidden_size=module.out_proj.out_features, cross_attention_dim=None
                 )
         text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
         temp_pipeline = DiffusionPipeline.from_pretrained(
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 84e6b4e61f0f..42625270c12e 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -72,8 +72,8 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]):
         self.mapping = dict(enumerate(state_dict.keys()))
         self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
 
-        # .processor for unet, .k_proj, ".q_proj", ".v_proj", and ".out_proj" for text encoder
-        self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]
+        # .processor for unet, .self_attn for text encoder
+        self.split_keys = [".processor", ".self_attn"]
 
         # we add a hook to state_dict() and load_state_dict() so that the
         # naming fits with `unet.attn_processors`
@@ -182,6 +182,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+        network_alpha = kwargs.pop("network_alpha", None)
 
         if use_safetensors and not is_safetensors_available():
             raise ValueError(
@@ -287,7 +290,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                     attn_processor_class = LoRAAttnProcessor
 
                 attn_processors[key] = attn_processor_class(
-                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
+                    hidden_size=hidden_size,
+                    cross_attention_dim=cross_attention_dim,
+                    rank=rank,
+                    network_alpha=network_alpha,
                 )
                 attn_processors[key].load_state_dict(value_dict)
         elif is_custom_diffusion:
@@ -774,6 +780,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
 
+        We support loading A1111-formatted LoRA checkpoints in a limited capacity.
+        This function is experimental and might change in the future.
 
@@ -898,6 +906,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         else:
             state_dict = pretrained_model_name_or_path_or_dict
 
+        # Convert kohya-ss style LoRA attn procs to diffusers attn procs
+        network_alpha = None
+        if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
+            state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)
+
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.
@@ -909,7 +922,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             unet_lora_state_dict = {
                 k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
             }
-            self.unet.load_attn_procs(unet_lora_state_dict)
+            self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
 
             # Load the layers corresponding to text encoder and make necessary adjustments.
             text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
@@ -918,7 +931,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
             }
             if len(text_encoder_lora_state_dict) > 0:
-                attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
+                attn_procs_text_encoder = self._load_text_encoder_attn_procs(
+                    text_encoder_lora_state_dict, network_alpha=network_alpha
+                )
                 self._modify_text_encoder(attn_procs_text_encoder)
 
                 # save lora attn procs of text encoder so that it can be easily retrieved
@@ -954,14 +969,20 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
                 module = self.text_encoder.get_submodule(name)
                 # Construct a new function that performs the LoRA merging. We will monkey patch
                 # this forward pass.
-                lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
+                attn_processor_name = ".".join(name.split(".")[:-1])
+                lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name))
                 old_forward = module.forward
 
-                def new_forward(x):
-                    return old_forward(x) + lora_layer(x)
+                # create a new scope that locks in the current old_forward and lora_layer for each new_forward function
+                # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
+                def make_new_forward(old_forward, lora_layer):
+                    def new_forward(x):
+                        return old_forward(x) + lora_layer(x)
+
+                    return new_forward
 
                 # Monkey-patch.
-                module.forward = new_forward
+                module.forward = make_new_forward(old_forward, lora_layer)
 
     def _get_lora_layer_attribute(self, name: str) -> str:
         if "q_proj" in name:
@@ -1048,6 +1069,7 @@ def _load_text_encoder_attn_procs(
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        network_alpha = kwargs.pop("network_alpha", None)
 
         if use_safetensors and not is_safetensors_available():
             raise ValueError(
@@ -1125,7 +1147,10 @@ def _load_text_encoder_attn_procs(
                 hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
 
                 attn_processors[key] = LoRAAttnProcessor(
-                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
+                    hidden_size=hidden_size,
+                    cross_attention_dim=cross_attention_dim,
+                    rank=rank,
+                    network_alpha=network_alpha,
                 )
                 attn_processors[key].load_state_dict(value_dict)
 
@@ -1219,6 +1244,56 @@ def save_function(weights, filename):
 
         save_function(state_dict, os.path.join(save_directory, weight_name))
 
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
+    def _convert_kohya_lora_to_diffusers(self, state_dict):
+        unet_state_dict = {}
+        te_state_dict = {}
+        network_alpha = None
+
+        for key, value in state_dict.items():
+            if "lora_down" in key:
+                lora_name = key.split(".")[0]
+                lora_name_up = lora_name + ".lora_up.weight"
+                lora_name_alpha = lora_name + ".alpha"
+                if lora_name_alpha in state_dict:
+                    alpha = state_dict[lora_name_alpha].item()
+                    if network_alpha is None:
+                        network_alpha = alpha
+                    elif network_alpha != alpha:
+                        raise ValueError("Network alpha is not consistent")
+
+                if lora_name.startswith("lora_unet_"):
+                    diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
+                    diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+                    diffusers_name = diffusers_name.replace("mid.block", "mid_block")
+                    diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
+                    diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
+                    diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
+                    diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
+                    diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
+                    diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+                    if "transformer_blocks" in diffusers_name:
+                        if "attn1" in diffusers_name or "attn2" in diffusers_name:
+                            diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
+                            diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
+                            unet_state_dict[diffusers_name] = value
+                            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                elif lora_name.startswith("lora_te_"):
+                    diffusers_name = key.replace("lora_te_", "").replace("_", ".")
+                    diffusers_name = diffusers_name.replace("text.model", "text_model")
"text_model") + diffusers_name = diffusers_name.replace("self.attn", "self_attn") + diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") + if "self_attn" in diffusers_name: + te_state_dict[diffusers_name] = value + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + + unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()} + te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()} + new_state_dict = {**unet_state_dict, **te_state_dict} + return new_state_dict, network_alpha + class FromCkptMixin: """This helper class allows to directly load .ckpt stable diffusion file_extension diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e39bdc0429c1..61a1faea07f4 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -508,7 +508,7 @@ def __call__( class LoRALinearLayer(nn.Module): - def __init__(self, in_features, out_features, rank=4): + def __init__(self, in_features, out_features, rank=4, network_alpha=None): super().__init__() if rank > min(in_features, out_features): @@ -516,6 +516,10 @@ def __init__(self, in_features, out_features, rank=4): self.down = nn.Linear(in_features, rank, bias=False) self.up = nn.Linear(rank, out_features, bias=False) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. + # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank nn.init.normal_(self.down.weight, std=1 / rank) nn.init.zeros_(self.up.weight) @@ -527,6 +531,9 @@ def forward(self, hidden_states): down_hidden_states = self.down(hidden_states.to(dtype)) up_hidden_states = self.up(down_hidden_states) + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + return up_hidden_states.to(orig_dtype) @@ -543,17 +550,17 @@ class LoRAAttnProcessor(nn.Module): The dimension of the LoRA update matrices. 
""" - def __init__(self, hidden_size, cross_attention_dim=None, rank=4): + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): super().__init__() self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.rank = rank - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) def __call__( self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None @@ -838,19 +845,19 @@ class LoRAAttnAddedKVProcessor(nn.Module): The dimension of the LoRA update matrices. """ - def __init__(self, hidden_size, cross_attention_dim=None, rank=4): + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): super().__init__() self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.rank = rank - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): residual = hidden_states @@ -1157,7 +1164,9 @@ class LoRAXFormersAttnProcessor(nn.Module): operator. 
""" - def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None): + def __init__( + self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None + ): super().__init__() self.hidden_size = hidden_size @@ -1165,10 +1174,10 @@ def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optio self.rank = rank self.attention_op = attention_op - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) def __call__( self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index cd3a1b8f3dd4..772c36b1177b 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -30,6 +30,7 @@ ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, + TEXT_ENCODER_ATTN_MODULE, TEXT_ENCODER_TARGET_MODULES, WEIGHTS_NAME, ) diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 1134ba6fb656..93d5c8cc42cd 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -31,3 +31,4 @@ HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"] +TEXT_ENCODER_ATTN_MODULE = ".self_attn" diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 64e30ba4057d..d04d87e08b7a 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import gc
 import os
 import tempfile
 import unittest
 
@@ -30,7 +31,7 @@
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
-from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device
+from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, floats_tensor, torch_device
 
 
 def create_unet_lora_layers(unet: nn.Module):
@@ -50,15 +51,35 @@ def create_unet_lora_layers(unet: nn.Module):
     return lora_attn_procs, unet_lora_layers
 
 
-def create_text_encoder_lora_layers(text_encoder: nn.Module):
+def create_text_encoder_lora_attn_procs(text_encoder: nn.Module):
     text_lora_attn_procs = {}
     for name, module in text_encoder.named_modules():
-        if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
-            text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None)
+        if name.endswith(TEXT_ENCODER_ATTN_MODULE):
+            text_lora_attn_procs[name] = LoRAAttnProcessor(
+                hidden_size=module.out_proj.out_features, cross_attention_dim=None
+            )
+    return text_lora_attn_procs
+
+
+def create_text_encoder_lora_layers(text_encoder: nn.Module):
+    text_lora_attn_procs = create_text_encoder_lora_attn_procs(text_encoder)
     text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
     return text_encoder_lora_layers
 
 
+def set_lora_up_weights(text_lora_attn_procs, randn_weight=False):
+    for _, attn_proc in text_lora_attn_procs.items():
+        # set each LoRA layer's up.weight to zeros by default, or to random values when requested
+        for layer_name, layer_module in attn_proc.named_modules():
+            if layer_name.endswith("_lora"):
+                weight = (
+                    torch.randn_like(layer_module.up.weight)
+                    if randn_weight
+                    else torch.zeros_like(layer_module.up.weight)
+                )
+                layer_module.up.weight = torch.nn.Parameter(weight)
+
+
 class LoraLoaderMixinTests(unittest.TestCase):
     def get_dummy_components(self):
         torch.manual_seed(0)
@@ -220,6 +241,64 @@ def test_lora_save_load_legacy(self):
         # Outputs shouldn't match.
         self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
 
+    # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
+    def get_dummy_tokens(self):
+        max_seq_length = 77
+
+        inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
+
+        prepared_inputs = {}
+        prepared_inputs["input_ids"] = inputs
+        return prepared_inputs
+
+    def test_text_encoder_lora_monkey_patch(self):
+        pipeline_components, _ = self.get_dummy_components()
+        pipe = StableDiffusionPipeline(**pipeline_components)
+
+        dummy_tokens = self.get_dummy_tokens()
+
+        # inference without lora
+        outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_without_lora.shape == (1, 77, 32)
+
+        # create lora_attn_procs with zeroed out up.weights
+        text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
+        set_lora_up_weights(text_attn_procs, randn_weight=False)
+
+        # monkey patch
+        pipe._modify_text_encoder(text_attn_procs)
+
+        # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
+        del text_attn_procs
+        gc.collect()
+
+        # inference with lora
+        outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_with_lora.shape == (1, 77, 32)
+
+        assert torch.allclose(
+            outputs_without_lora, outputs_with_lora
+        ), "lora up.weights are all zero, so the lora outputs should be the same as the outputs without lora"
+
+        # create lora_attn_procs with randn up.weights
+        text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
+        set_lora_up_weights(text_attn_procs, randn_weight=True)
+
+        # monkey patch
+        pipe._modify_text_encoder(text_attn_procs)
+
+        # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
+        del text_attn_procs
+        gc.collect()
+
+        # inference with lora
+        outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_with_lora.shape == (1, 77, 32)
+
+        assert not torch.allclose(
+            outputs_without_lora, outputs_with_lora
+        ), "lora up.weights are not zero, so the lora outputs should be different from the outputs without lora"
+
     def create_lora_weight_file(self, tmpdirname):
         _, lora_components = self.get_dummy_components()
         LoraLoaderMixin.save_lora_weights(
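To make the key renaming in `_convert_kohya_lora_to_diffusers` above easier to follow, here is a minimal, self-contained sketch of the same transformation applied to one text encoder key, together with the `network_alpha / rank` scaling performed by `LoRALinearLayer.forward`. The sample key, tensor shapes, and alpha value are invented for illustration and are not taken from a real checkpoint:

```python
import torch

# A hypothetical kohya-ss key for the q_proj LoRA down-projection of the
# first text encoder layer:
kohya_key = "lora_te_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight"

# Strip the prefix and turn "_" into ".", then repair the identifiers that
# legitimately contain underscores, mirroring the replacements in the diff:
name = kohya_key.replace("lora_te_", "").replace("_", ".")
name = name.replace("text.model", "text_model")
name = name.replace("self.attn", "self_attn")
name = name.replace("q.proj.lora", "to_q_lora")
print(name)  # text_model.encoder.layers.0.self_attn.to_q_lora.down.weight

# The effective LoRA update is scaled by network_alpha / rank, matching
# LoRALinearLayer.forward above (dimensions and values are illustrative):
rank, network_alpha, dim = 4, 1.0, 768
down = torch.randn(rank, dim)  # down.weight, shape (rank, in_features)
up = torch.randn(dim, rank)    # up.weight, shape (out_features, rank)
x = torch.randn(1, dim)
lora_update = (x @ down.T @ up.T) * (network_alpha / rank)
```

Dividing by `rank` means that, for a fixed alpha stored in the checkpoint under `lora_name + ".alpha"`, the magnitude of the LoRA update stays comparable when checkpoints are trained at different ranks, which appears to be the rationale described in the kohya-ss training docs.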