
Commit a89a14f

[LoRA] Enabling limited LoRA support for text encoder (#2918)
* add: first draft for a better LoRA enabler.
* make fix-copies.
* feat: backward compatibility.
* add: entry to the docs.
* add: tests.
* fix: docs.
* fix: norm group test for UNet3D.
* feat: add support for flat dicts.
* add deprecation message instead of warning.
1 parent e607a58 commit a89a14f
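The headline change: `StableDiffusionPipeline` now inherits `LoraLoaderMixin`, so a LoRA checkpoint covering both the UNet and the text encoder can be loaded with one call. A minimal usage sketch, assuming the API exercised by the tests in this commit; the model id and directory path are illustrative placeholders:

```python
from diffusers import StableDiffusionPipeline

# Illustrative base checkpoint; any Stable Diffusion pipeline works the same way.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# New in this commit: load_lora_weights() comes from LoraLoaderMixin and applies
# LoRA weights to the UNet *and* (if present in the checkpoint) the text encoder.
# "path/to/lora" is a placeholder directory containing pytorch_lora_weights.bin
# or pytorch_lora_weights.safetensors.
pipe.load_lora_weights("path/to/lora")

image = pipe("A painting of a squirrel eating a burger").images[0]
```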

File tree

6 files changed (+682, -11 lines)


docs/source/en/api/loaders.mdx

Lines changed: 8 additions & 0 deletions
```diff
@@ -28,3 +28,11 @@ API to load such adapter neural networks via the [`loaders.py` module](https://g
 ### UNet2DConditionLoadersMixin

 [[autodoc]] loaders.UNet2DConditionLoadersMixin
+
+### TextualInversionLoaderMixin
+
+[[autodoc]] loaders.TextualInversionLoaderMixin
+
+### LoraLoaderMixin
+
+[[autodoc]] loaders.LoraLoaderMixin
```

src/diffusers/loaders.py

Lines changed: 457 additions & 9 deletions
Large diffs are not rendered by default.
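The bulk of the change, the new `LoraLoaderMixin`, lives in this collapsed diff. Its save-side entry point, as exercised by the tests further down, bundles UNet and text encoder LoRA layers into a single checkpoint. A sketch of that flow, reusing the `create_unet_lora_layers` and `create_text_encoder_lora_layers` helpers from the test file below; the tiny test checkpoint and output directory are assumptions:

```python
from diffusers import StableDiffusionPipeline
from diffusers.loaders import LoraLoaderMixin

# Tiny test checkpoint, assumed here only to keep the sketch lightweight.
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")

# Helpers defined in tests/test_lora_layers.py (reproduced below).
_, unet_lora_layers = create_unet_lora_layers(pipe.unet)
text_encoder_lora_layers = create_text_encoder_lora_layers(pipe.text_encoder)

# One checkpoint for both models: pytorch_lora_weights.bin by default,
# pytorch_lora_weights.safetensors with safe_serialization=True.
LoraLoaderMixin.save_lora_weights(
    save_directory="lora-ckpt",  # hypothetical output directory
    unet_lora_layers=unet_lora_layers,
    text_encoder_lora_layers=text_encoder_lora_layers,
    safe_serialization=True,
)
```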

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -20,7 +20,7 @@
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -53,7 +53,7 @@
 """


-class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
```
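Only the import line and the class's base list change; the pipeline picks up `load_lora_weights` purely through multiple inheritance. Schematically (the pipeline class name here is hypothetical):

```python
from diffusers import DiffusionPipeline
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin

# Each mixin contributes an orthogonal loading capability
# (load_textual_inversion, load_lora_weights) without touching
# the pipeline's own denoising logic.
class MyCustomPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
    pass
```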

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -30,6 +30,7 @@
     ONNX_EXTERNAL_WEIGHTS_NAME,
     ONNX_WEIGHTS_NAME,
     SAFETENSORS_WEIGHTS_NAME,
+    TEXT_ENCODER_TARGET_MODULES,
     WEIGHTS_NAME,
 )
 from .deprecation_utils import deprecate
```

src/diffusers/utils/constants.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -30,3 +30,4 @@
 DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
 HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
 DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
+TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"]
```
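The four names target the query, key, value, and output projections of each attention block in the CLIP text encoder; the test helper below matches them as substrings of `named_modules()` names. A quick self-contained check of what gets matched (the tiny config values are arbitrary):

```python
from transformers import CLIPTextConfig, CLIPTextModel

TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"]

# Tiny config, values chosen only to keep the model small.
config = CLIPTextConfig(
    hidden_size=32, intermediate_size=37, num_attention_heads=4,
    num_hidden_layers=2, vocab_size=1000,
)
text_encoder = CLIPTextModel(config)

# Matches the nn.Linear projections inside each CLIPAttention block, e.g.
# "text_model.encoder.layers.0.self_attn.q_proj".
targets = [
    name
    for name, _ in text_encoder.named_modules()
    if any(t in name for t in TEXT_ENCODER_TARGET_MODULES)
]
print(targets)  # 4 projections per layer -> 8 names for 2 layers
```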

tests/test_lora_layers.py

Lines changed: 213 additions & 0 deletions
```python
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest

import torch
import torch.nn as nn
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device


def create_unet_lora_layers(unet: nn.Module):
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
    unet_lora_layers = AttnProcsLayers(lora_attn_procs)
    return lora_attn_procs, unet_lora_layers


def create_text_encoder_lora_layers(text_encoder: nn.Module):
    text_lora_attn_procs = {}
    for name, module in text_encoder.named_modules():
        if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]):
            text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None)
    text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
    return text_encoder_lora_layers


class LoraLoaderMixinTests(unittest.TestCase):
    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
        text_encoder_lora_layers = create_text_encoder_lora_layers(text_encoder)

        pipeline_components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }
        lora_components = {
            "unet_lora_layers": unet_lora_layers,
            "text_encoder_lora_layers": text_encoder_lora_layers,
            "unet_lora_attn_procs": unet_lora_attn_procs,
        }
        return pipeline_components, lora_components

    def get_dummy_inputs(self):
        batch_size = 1
        sequence_length = 10
        num_channels = 4
        sizes = (32, 32)

        generator = torch.manual_seed(0)
        noise = floats_tensor((batch_size, num_channels) + sizes)
        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

        pipeline_inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "numpy",
        }

        return noise, input_ids, pipeline_inputs

    def test_lora_save_load(self):
        pipeline_components, lora_components = self.get_dummy_components()
        sd_pipe = StableDiffusionPipeline(**pipeline_components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        noise, input_ids, pipeline_inputs = self.get_dummy_inputs()

        original_images = sd_pipe(**pipeline_inputs).images
        orig_image_slice = original_images[0, -3:, -3:, -1]

        with tempfile.TemporaryDirectory() as tmpdirname:
            LoraLoaderMixin.save_lora_weights(
                save_directory=tmpdirname,
                unet_lora_layers=lora_components["unet_lora_layers"],
                text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
            )
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
            sd_pipe.load_lora_weights(tmpdirname)

        lora_images = sd_pipe(**pipeline_inputs).images
        lora_image_slice = lora_images[0, -3:, -3:, -1]

        # Outputs shouldn't match.
        self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))

    def test_lora_save_load_safetensors(self):
        pipeline_components, lora_components = self.get_dummy_components()
        sd_pipe = StableDiffusionPipeline(**pipeline_components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        noise, input_ids, pipeline_inputs = self.get_dummy_inputs()

        original_images = sd_pipe(**pipeline_inputs).images
        orig_image_slice = original_images[0, -3:, -3:, -1]

        with tempfile.TemporaryDirectory() as tmpdirname:
            LoraLoaderMixin.save_lora_weights(
                save_directory=tmpdirname,
                unet_lora_layers=lora_components["unet_lora_layers"],
                text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
                safe_serialization=True,
            )
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
            sd_pipe.load_lora_weights(tmpdirname)

        lora_images = sd_pipe(**pipeline_inputs).images
        lora_image_slice = lora_images[0, -3:, -3:, -1]

        # Outputs shouldn't match.
        self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))

    def test_lora_save_load_legacy(self):
        pipeline_components, lora_components = self.get_dummy_components()
        unet_lora_attn_procs = lora_components["unet_lora_attn_procs"]
        sd_pipe = StableDiffusionPipeline(**pipeline_components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        noise, input_ids, pipeline_inputs = self.get_dummy_inputs()

        original_images = sd_pipe(**pipeline_inputs).images
        orig_image_slice = original_images[0, -3:, -3:, -1]

        with tempfile.TemporaryDirectory() as tmpdirname:
            unet = sd_pipe.unet
            unet.set_attn_processor(unet_lora_attn_procs)
            unet.save_attn_procs(tmpdirname)
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
            sd_pipe.load_lora_weights(tmpdirname)

        lora_images = sd_pipe(**pipeline_inputs).images
        lora_image_slice = lora_images[0, -3:, -3:, -1]

        # Outputs shouldn't match.
        self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
```
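The last test pins down backward compatibility: a UNet-only checkpoint written by the pre-existing `unet.save_attn_procs` still loads through the new `load_lora_weights` (per the commit message, flat dicts are accepted with a deprecation message). Condensed, that legacy flow looks like this; the model id and directory are placeholders:

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # illustrative

# Old-style save: only the UNet's LoRA attention processors are serialized.
# Assumes LoRA processors were set beforehand, e.g. via
# pipe.unet.set_attn_processor(lora_attn_procs).
pipe.unet.save_attn_procs("legacy-lora")  # placeholder directory

# New-style load: accepts the legacy UNet-only layout as well as the new
# combined UNet + text encoder checkpoints.
pipe.load_lora_weights("legacy-lora")
```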
