diff --git a/docs/source/en/training/lora.mdx b/docs/source/en/training/lora.mdx
index 748d99d5020d..484b08ce950a 100644
--- a/docs/source/en/training/lora.mdx
+++ b/docs/source/en/training/lora.mdx
@@ -272,4 +272,75 @@ Note that the use of [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] is
* LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth).
**Note** that it is possible to provide a local directory path to [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] as well as [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`]. To know about the supported inputs,
-refer to the respective docstrings.
\ No newline at end of file
+refer to the respective docstrings.
+
+## Supporting A1111 formatted LoRA checkpoints in Diffusers
+
+To provide seamless interoperability with A1111 for our users, we support loading A1111 formatted
+LoRA checkpoints with [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] in a limited capacity.
+In this section, we explain how to load an A1111 formatted LoRA checkpoint from [CivitAI](https://civitai.com/)
+into Diffusers and run inference with it.
+
+First, download a checkpoint. We'll use
+[this one](https://civitai.com/models/13239/light-and-shadow) for demonstration purposes.
+
+```bash
+wget https://civitai.com/api/download/models/15603 -O light_and_shadow.safetensors
+```
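+
+If you prefer to stay in Python, the same file can be fetched programmatically. Below is a minimal sketch using `requests` (an assumption on our side; `wget`, `curl`, or any other HTTP client works just as well):
+
+```python
+import requests
+
+# Download the checkpoint from CivitAI and write it to the current directory.
+response = requests.get("https://civitai.com/api/download/models/15603", timeout=60)
+response.raise_for_status()
+
+with open("light_and_shadow.safetensors", "wb") as f:
+    f.write(response.content)
+```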
+
+Next, we initialize a [`~DiffusionPipeline`]:
+
+```python
+import torch
+
+from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
+
+pipeline = StableDiffusionPipeline.from_pretrained(
+    "gsdf/Counterfeit-V2.5", torch_dtype=torch.float16, safety_checker=None
+).to("cuda")
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+    pipeline.scheduler.config, use_karras_sigmas=True
+)
+```
+
+We then load the checkpoint downloaded from CivitAI:
+
+```python
+pipeline.load_lora_weights(".", weight_name="light_and_shadow.safetensors")
+```
+
+
+<Tip>
+
+If you're loading a checkpoint in the `safetensors` format, please ensure you have `safetensors` installed.
+
+</Tip>
+
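+
+Because [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] also accepts an already-loaded state dict, you can read the file yourself and pass the dictionary in directly. Here is a small sketch of that variant (it assumes `safetensors` is installed):
+
+```python
+from safetensors.torch import load_file
+
+# Load the checkpoint into a plain dict of tensors and hand it to the pipeline.
+state_dict = load_file("light_and_shadow.safetensors")
+pipeline.load_lora_weights(state_dict)
+```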
+Now, it's time to run inference:
+
+```python
+prompt = "masterpiece, best quality, 1girl, at dusk"
+negative_prompt = (
+    "(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), "
+    "bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2), large breasts"
+)
+
+images = pipeline(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    width=512,
+    height=768,
+    num_inference_steps=15,
+    num_images_per_prompt=4,
+    generator=torch.manual_seed(0),
+).images
+```
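+
+The pipeline returns a list of PIL images, so persisting the results only takes a couple of lines (a small usage sketch; the filenames are arbitrary):
+
+```python
+# Save each generated image to disk.
+for i, image in enumerate(images):
+    image.save(f"light_and_shadow_{i}.png")
+```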
+
+Below is a comparison between the LoRA and the non-LoRA results:
+
+
+
+If you have a similar checkpoint stored on the Hugging Face Hub, you can load it
+directly with [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] like so:
+
+```python
+lora_model_id = "sayakpaul/civitai-light-shadow-lora"
+lora_filename = "light_and_shadow.safetensors"
+pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
+```
\ No newline at end of file
diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 4ff759dcd6d4..4bd7b2ec3cb2 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -58,7 +58,7 @@
SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler
-from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available
+from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -847,9 +847,9 @@ def main(args):
if args.train_text_encoder:
text_lora_attn_procs = {}
for name, module in text_encoder.named_modules():
- if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
+ if name.endswith(TEXT_ENCODER_ATTN_MODULE):
text_lora_attn_procs[name] = LoRAAttnProcessor(
- hidden_size=module.out_features, cross_attention_dim=None
+ hidden_size=module.out_proj.out_features, cross_attention_dim=None
)
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
temp_pipeline = DiffusionPipeline.from_pretrained(
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 84e6b4e61f0f..42625270c12e 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -72,8 +72,8 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]):
self.mapping = dict(enumerate(state_dict.keys()))
self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
- # .processor for unet, .k_proj, ".q_proj", ".v_proj", and ".out_proj" for text encoder
- self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]
+ # .processor for unet, .self_attn for text encoder
+ self.split_keys = [".processor", ".self_attn"]
# we add a hook to state_dict() and load_state_dict() so that the
# naming fits with `unet.attn_processors`
@@ -182,6 +182,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
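+ # When provided, the LoRA layers scale their update by network_alpha / rank (see LoRALinearLayer.forward).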
+ network_alpha = kwargs.pop("network_alpha", None)
if use_safetensors and not is_safetensors_available():
raise ValueError(
@@ -287,7 +290,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
attn_processor_class = LoRAAttnProcessor
attn_processors[key] = attn_processor_class(
- hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ rank=rank,
+ network_alpha=network_alpha,
)
attn_processors[key].load_state_dict(value_dict)
elif is_custom_diffusion:
@@ -774,6 +780,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
This function is experimental and might change in the future.
@@ -898,6 +906,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
else:
state_dict = pretrained_model_name_or_path_or_dict
+ # Convert kohya-ss style LoRA attn procs to diffusers attn procs
+ network_alpha = None
+ if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
+ state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)
+
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
@@ -909,7 +922,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
unet_lora_state_dict = {
k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
}
- self.unet.load_attn_procs(unet_lora_state_dict)
+ self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
@@ -918,7 +931,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
- attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
+ attn_procs_text_encoder = self._load_text_encoder_attn_procs(
+ text_encoder_lora_state_dict, network_alpha=network_alpha
+ )
self._modify_text_encoder(attn_procs_text_encoder)
# save lora attn procs of text encoder so that it can be easily retrieved
@@ -954,14 +969,20 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
module = self.text_encoder.get_submodule(name)
# Construct a new function that performs the LoRA merging. We will monkey patch
# this forward pass.
- lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
+ attn_processor_name = ".".join(name.split(".")[:-1])
+ lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name))
old_forward = module.forward
- def new_forward(x):
- return old_forward(x) + lora_layer(x)
+ # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
+ # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
+ def make_new_forward(old_forward, lora_layer):
+ def new_forward(x):
+ return old_forward(x) + lora_layer(x)
+
+ return new_forward
# Monkey-patch.
- module.forward = new_forward
+ module.forward = make_new_forward(old_forward, lora_layer)
def _get_lora_layer_attribute(self, name: str) -> str:
if "q_proj" in name:
@@ -1048,6 +1069,7 @@ def _load_text_encoder_attn_procs(
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
+ network_alpha = kwargs.pop("network_alpha", None)
if use_safetensors and not is_safetensors_available():
raise ValueError(
@@ -1125,7 +1147,10 @@ def _load_text_encoder_attn_procs(
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
attn_processors[key] = LoRAAttnProcessor(
- hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ rank=rank,
+ network_alpha=network_alpha,
)
attn_processors[key].load_state_dict(value_dict)
@@ -1219,6 +1244,56 @@ def save_function(weights, filename):
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
+ def _convert_kohya_lora_to_diffusers(self, state_dict):
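+ # kohya-ss checkpoints store LoRA weights under keys such as
+ # `lora_unet_<module>.lora_down.weight` / `.lora_up.weight` (and `lora_te_<module>.*`
+ # for the text encoder), plus an optional per-module scalar `<module>.alpha`.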
+ unet_state_dict = {}
+ te_state_dict = {}
+ network_alpha = None
+
+ for key, value in state_dict.items():
+ if "lora_down" in key:
+ lora_name = key.split(".")[0]
+ lora_name_up = lora_name + ".lora_up.weight"
+ lora_name_alpha = lora_name + ".alpha"
+ if lora_name_alpha in state_dict:
+ alpha = state_dict[lora_name_alpha].item()
+ if network_alpha is None:
+ network_alpha = alpha
+ elif network_alpha != alpha:
+ raise ValueError("Network alpha is not consistent")
+
+ if lora_name.startswith("lora_unet_"):
+ diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
+ diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+ diffusers_name = diffusers_name.replace("mid.block", "mid_block")
+ diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
+ diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
+ diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
+ diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
+ diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
+ diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+ if "transformer_blocks" in diffusers_name:
+ if "attn1" in diffusers_name or "attn2" in diffusers_name:
+ diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
+ diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
+ unet_state_dict[diffusers_name] = value
+ unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+ elif lora_name.startswith("lora_te_"):
+ diffusers_name = key.replace("lora_te_", "").replace("_", ".")
+ diffusers_name = diffusers_name.replace("text.model", "text_model")
+ diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+ diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+ diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+ diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+ diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+ if "self_attn" in diffusers_name:
+ te_state_dict[diffusers_name] = value
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+
+ unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
+ te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
+ new_state_dict = {**unet_state_dict, **te_state_dict}
+ return new_state_dict, network_alpha
+
class FromCkptMixin:
"""This helper class allows to directly load .ckpt stable diffusion file_extension
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index e39bdc0429c1..61a1faea07f4 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -508,7 +508,7 @@ def __call__(
class LoRALinearLayer(nn.Module):
- def __init__(self, in_features, out_features, rank=4):
+ def __init__(self, in_features, out_features, rank=4, network_alpha=None):
super().__init__()
if rank > min(in_features, out_features):
@@ -516,6 +516,10 @@ def __init__(self, in_features, out_features, rank=4):
self.down = nn.Linear(in_features, rank, bias=False)
self.up = nn.Linear(rank, out_features, bias=False)
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+ self.network_alpha = network_alpha
+ self.rank = rank
nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)
@@ -527,6 +531,9 @@ def forward(self, hidden_states):
down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)
+ if self.network_alpha is not None:
+ up_hidden_states *= self.network_alpha / self.rank
+
return up_hidden_states.to(orig_dtype)
@@ -543,17 +550,17 @@ class LoRAAttnProcessor(nn.Module):
The dimension of the LoRA update matrices.
"""
- def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
- self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
- self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
- self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
- self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
def __call__(
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
@@ -838,19 +845,19 @@ class LoRAAttnAddedKVProcessor(nn.Module):
The dimension of the LoRA update matrices.
"""
- def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
- self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
- self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
- self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
- self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
- self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
- self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+ self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+ self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
residual = hidden_states
@@ -1157,7 +1164,9 @@ class LoRAXFormersAttnProcessor(nn.Module):
operator.
"""
- def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
+ def __init__(
+ self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
+ ):
super().__init__()
self.hidden_size = hidden_size
@@ -1165,10 +1174,10 @@ def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optio
self.rank = rank
self.attention_op = attention_op
- self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
- self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
- self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
- self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
def __call__(
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index cd3a1b8f3dd4..772c36b1177b 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -30,6 +30,7 @@
ONNX_EXTERNAL_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
+ TEXT_ENCODER_ATTN_MODULE,
TEXT_ENCODER_TARGET_MODULES,
WEIGHTS_NAME,
)
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index 1134ba6fb656..93d5c8cc42cd 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -31,3 +31,4 @@
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"]
+TEXT_ENCODER_ATTN_MODULE = ".self_attn"
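+# For the CLIP text encoder, this suffix matches module names like `text_model.encoder.layers.0.self_attn`.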
diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py
index 64e30ba4057d..d04d87e08b7a 100644
--- a/tests/models/test_lora_layers.py
+++ b/tests/models/test_lora_layers.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import gc
import os
import tempfile
import unittest
@@ -30,7 +31,7 @@
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
-from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device
+from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, floats_tensor, torch_device
def create_unet_lora_layers(unet: nn.Module):
@@ -50,15 +51,35 @@ def create_unet_lora_layers(unet: nn.Module):
return lora_attn_procs, unet_lora_layers
-def create_text_encoder_lora_layers(text_encoder: nn.Module):
+def create_text_encoder_lora_attn_procs(text_encoder: nn.Module):
text_lora_attn_procs = {}
for name, module in text_encoder.named_modules():
- if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
- text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None)
+ if name.endswith(TEXT_ENCODER_ATTN_MODULE):
+ text_lora_attn_procs[name] = LoRAAttnProcessor(
+ hidden_size=module.out_proj.out_features, cross_attention_dim=None
+ )
+ return text_lora_attn_procs
+
+
+def create_text_encoder_lora_layers(text_encoder: nn.Module):
+ text_lora_attn_procs = create_text_encoder_lora_attn_procs(text_encoder)
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
return text_encoder_lora_layers
+def set_lora_up_weights(text_lora_attn_procs, randn_weight=False):
+ for _, attn_proc in text_lora_attn_procs.items():
+ # set up.weights
+ for layer_name, layer_module in attn_proc.named_modules():
+ if layer_name.endswith("_lora"):
+ weight = (
+ torch.randn_like(layer_module.up.weight)
+ if randn_weight
+ else torch.zeros_like(layer_module.up.weight)
+ )
+ layer_module.up.weight = torch.nn.Parameter(weight)
+
+
class LoraLoaderMixinTests(unittest.TestCase):
def get_dummy_components(self):
torch.manual_seed(0)
@@ -220,6 +241,64 @@ def test_lora_save_load_legacy(self):
# Outputs shouldn't match.
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
+ # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
+ def get_dummy_tokens(self):
+ max_seq_length = 77
+
+ inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
+
+ prepared_inputs = {}
+ prepared_inputs["input_ids"] = inputs
+ return prepared_inputs
+
+ def test_text_encoder_lora_monkey_patch(self):
+ pipeline_components, _ = self.get_dummy_components()
+ pipe = StableDiffusionPipeline(**pipeline_components)
+
+ dummy_tokens = self.get_dummy_tokens()
+
+ # inference without lora
+ outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
+ assert outputs_without_lora.shape == (1, 77, 32)
+
+ # create lora_attn_procs with zeroed out up.weights
+ text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
+ set_lora_up_weights(text_attn_procs, randn_weight=False)
+
+ # monkey patch
+ pipe._modify_text_encoder(text_attn_procs)
+
+ # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
+ del text_attn_procs
+ gc.collect()
+
+ # inference with lora
+ outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
+ assert outputs_with_lora.shape == (1, 77, 32)
+
+ assert torch.allclose(
+ outputs_without_lora, outputs_with_lora
+ ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs"
+
+ # create lora_attn_procs with randn up.weights
+ text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
+ set_lora_up_weights(text_attn_procs, randn_weight=True)
+
+ # monkey patch
+ pipe._modify_text_encoder(text_attn_procs)
+
+ # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
+ del text_attn_procs
+ gc.collect()
+
+ # inference with lora
+ outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
+ assert outputs_with_lora.shape == (1, 77, 32)
+
+ assert not torch.allclose(
+ outputs_without_lora, outputs_with_lora
+ ), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs"
+
def create_lora_weight_file(self, tmpdirname):
_, lora_components = self.get_dummy_components()
LoraLoaderMixin.save_lora_weights(