
Commit 8669e83

[LoRA] feat: add lora attention processor for pt 2.0. (#3594)
* feat: add lora attention processor for pt 2.0.
* explicit context manager for SDPA.
* switch to flash attention
* make shapes compatible to work optimally with SDPA.
* fix: circular import problem.
* explicitly specify the flash attention kernel in sdpa
* fall back to efficient attention context manager.
* remove explicit dispatch.
* fix: removed processor.
* fix: remove optional from type annotation.
* feat: make changes regarding LoRAAttnProcessor2_0.
* remove confusing warning.
* formatting.
* relax tolerance for PT 2.0
* fix: loading message.
* remove unnecessary logging.
* add: entry to the docs.
* add: network_alpha argument.
* relax tolerance.
1 parent b45204e commit 8669e83

File tree

6 files changed, +137 -20 lines changed


docs/source/en/api/attnprocessor.mdx

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,9 @@ An attention processor is a class for applying different types of attention mechanisms.
 ## LoRAAttnProcessor
 [[autodoc]] models.attention_processor.LoRAAttnProcessor
 
+## LoRAAttnProcessor2_0
+[[autodoc]] models.attention_processor.LoRAAttnProcessor2_0
+
 ## CustomDiffusionAttnProcessor
 [[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 4 additions & 2 deletions
@@ -55,6 +55,7 @@
     AttnAddedKVProcessor2_0,
     LoRAAttnAddedKVProcessor,
     LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
     SlicedAttnAddedKVProcessor,
 )
 from diffusers.optimization import get_scheduler
@@ -844,8 +845,9 @@ def main(args):
         if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
             lora_attn_processor_class = LoRAAttnAddedKVProcessor
         else:
-            lora_attn_processor_class = LoRAAttnProcessor
-
+            lora_attn_processor_class = (
+                LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+            )
         unet_lora_attn_procs[name] = lora_attn_processor_class(
             hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
         )
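For orientation, here is a condensed sketch of the selection pattern the training script now follows when building the UNet's LoRA processors. This is only a sketch, assuming a plain `UNet2DConditionModel` with no added-KV attention blocks; the checkpoint id is a placeholder, not something referenced by the commit.

import torch.nn.functional as F
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Prefer the PyTorch 2.0 processor whenever scaled_dot_product_attention exists.
lora_attn_processor_class = (
    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)

unet_lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # Self-attention layers ("attn1") have no cross-attention dimension.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    unet_lora_attn_procs[name] = lora_attn_processor_class(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )

unet.set_attn_processor(unet_lora_attn_procs)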

src/diffusers/loaders.py

Lines changed: 6 additions & 2 deletions
@@ -18,6 +18,7 @@
 from typing import Callable, Dict, List, Optional, Union
 
 import torch
+import torch.nn.functional as F
 from huggingface_hub import hf_hub_download
 
 from .models.attention_processor import (
@@ -27,6 +28,7 @@
     CustomDiffusionXFormersAttnProcessor,
     LoRAAttnAddedKVProcessor,
     LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
     LoRAXFormersAttnProcessor,
     SlicedAttnAddedKVProcessor,
     XFormersAttnProcessor,
@@ -287,7 +289,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
             if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
                 attn_processor_class = LoRAXFormersAttnProcessor
             else:
-                attn_processor_class = LoRAAttnProcessor
+                attn_processor_class = (
+                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+                )
 
             attn_processors[key] = attn_processor_class(
                 hidden_size=hidden_size,
@@ -927,11 +931,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
 
         # Load the layers corresponding to text encoder and make necessary adjustments.
         text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
-        logger.info(f"Loading {self.text_encoder_name}.")
         text_encoder_lora_state_dict = {
            k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
         }
         if len(text_encoder_lora_state_dict) > 0:
+            logger.info(f"Loading {self.text_encoder_name}.")
             attn_procs_text_encoder = self._load_text_encoder_attn_procs(
                 text_encoder_lora_state_dict, network_alpha=network_alpha
             )
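From the caller's side the loader change is transparent: the same loading call now resolves to `LoRAAttnProcessor2_0` on PyTorch 2.0 and keeps `LoRAAttnProcessor` otherwise. A minimal usage sketch follows; the checkpoint id and LoRA path are placeholders, not artifacts referenced by this commit.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# With PyTorch 2.0 installed, this now attaches LoRAAttnProcessor2_0 under the hood;
# on older PyTorch it falls back to LoRAAttnProcessor.
pipe.load_lora_weights("path/to/lora")  # local directory or Hub repo id (placeholder)

image = pipe(
    "a prompt", num_inference_steps=25, cross_attention_kwargs={"scale": 0.8}
).images[0]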

src/diffusers/models/attention_processor.py

Lines changed: 106 additions & 11 deletions
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import warnings
 from typing import Callable, Optional, Union
 
 import torch
@@ -166,7 +165,8 @@ def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ):
         is_lora = hasattr(self, "processor") and isinstance(
-            self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor)
+            self.processor,
+            (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
         )
         is_custom_diffusion = hasattr(self, "processor") and isinstance(
             self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
@@ -200,14 +200,6 @@ def set_use_memory_efficient_attention_xformers(
                     "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                     " only available for GPU "
                 )
-            elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
-                warnings.warn(
-                    "You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
-                    "We will default to PyTorch's native efficient flash attention implementation (`F.scaled_dot_product_attention`) "
-                    "introduced in PyTorch 2.0. In case you are using LoRA or Custom Diffusion, we will fall "
-                    "back to their respective attention processors i.e., we will NOT use the PyTorch 2.0 "
-                    "native efficient flash attention."
-                )
             else:
                 try:
                     # Make sure we can run the memory efficient attention
@@ -220,6 +212,8 @@ def set_use_memory_efficient_attention_xformers(
                     raise e
 
             if is_lora:
+                # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
+                # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
                 processor = LoRAXFormersAttnProcessor(
                     hidden_size=self.processor.hidden_size,
                     cross_attention_dim=self.processor.cross_attention_dim,
@@ -252,7 +246,10 @@ def set_use_memory_efficient_attention_xformers(
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
             if is_lora:
-                processor = LoRAAttnProcessor(
+                attn_processor_class = (
+                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+                )
+                processor = attn_processor_class(
                     hidden_size=self.processor.hidden_size,
                     cross_attention_dim=self.processor.cross_attention_dim,
                     rank=self.processor.rank,
@@ -548,6 +545,8 @@ class LoRAAttnProcessor(nn.Module):
             The number of channels in the `encoder_hidden_states`.
         rank (`int`, defaults to 4):
             The dimension of the LoRA update matrices.
+        network_alpha (`int`, *optional*):
+            Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
     """
 
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
@@ -843,6 +842,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
             The number of channels in the `encoder_hidden_states`.
         rank (`int`, defaults to 4):
             The dimension of the LoRA update matrices.
+
     """
 
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
@@ -1162,6 +1162,9 @@ class LoRAXFormersAttnProcessor(nn.Module):
             [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
             use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
             operator.
+        network_alpha (`int`, *optional*):
+            Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
+
     """
 
     def __init__(
@@ -1236,6 +1239,97 @@ def __call__(
         return hidden_states
 
 
+class LoRAAttnProcessor2_0(nn.Module):
+    r"""
+    Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
+    attention.
+
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the `encoder_hidden_states`.
+        rank (`int`, defaults to 4):
+            The dimension of the LoRA update matrices.
+        network_alpha (`int`, *optional*):
+            Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
+        super().__init__()
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        inner_dim = hidden_states.shape[-1]
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
+
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class CustomDiffusionXFormersAttnProcessor(nn.Module):
     r"""
     Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
@@ -1520,6 +1614,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
     XFormersAttnAddedKVProcessor,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
+    LoRAAttnProcessor2_0,
    LoRAAttnAddedKVProcessor,
    CustomDiffusionAttnProcessor,
    CustomDiffusionXFormersAttnProcessor,
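To make the new processor's math concrete, here is a self-contained sketch (assumed shapes and helper names, not the library class itself) of what `LoRAAttnProcessor2_0` computes: a scaled low-rank update added to each frozen projection, followed by PyTorch 2.0's `F.scaled_dot_product_attention`. `TinyLoRALinear` is a stand-in for diffusers' `LoRALinearLayer`, following the usual LoRA convention of scaling by `network_alpha / rank`.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyLoRALinear(nn.Module):
    # Stand-in for LoRALinearLayer: x -> up(down(x)) * (network_alpha / rank).
    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        self.scale = (network_alpha / rank) if network_alpha is not None else 1.0
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, x):
        return self.up(self.down(x)) * self.scale

batch, seq_len, hidden, heads, rank = 2, 16, 64, 8, 4
head_dim = hidden // heads
x = torch.randn(batch, seq_len, hidden)

to_q = nn.Linear(hidden, hidden)  # frozen base projections
to_k = nn.Linear(hidden, hidden)
to_v = nn.Linear(hidden, hidden)
q_lora, k_lora, v_lora = (TinyLoRALinear(hidden, hidden, rank) for _ in range(3))

scale = 1.0  # LoRA strength, i.e. cross_attention_kwargs={"scale": ...}
query = to_q(x) + scale * q_lora(x)
key = to_k(x) + scale * k_lora(x)
value = to_v(x) + scale * v_lora(x)

# Reshape to the (batch, heads, seq_len, head_dim) layout SDPA expects.
query, key, value = (
    t.view(batch, -1, heads, head_dim).transpose(1, 2) for t in (query, key, value)
)
out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)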

tests/models/test_lora_layers.py

Lines changed: 16 additions & 3 deletions
@@ -19,6 +19,7 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
@@ -28,6 +29,7 @@
     AttnProcessor,
     AttnProcessor2_0,
     LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
@@ -46,16 +48,24 @@ def create_unet_lora_layers(unet: nn.Module):
         elif name.startswith("down_blocks"):
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
-        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        lora_attn_processor_class = (
+            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+        )
+        lora_attn_procs[name] = lora_attn_processor_class(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+        )
     unet_lora_layers = AttnProcsLayers(lora_attn_procs)
     return lora_attn_procs, unet_lora_layers
 
 
 def create_text_encoder_lora_attn_procs(text_encoder: nn.Module):
     text_lora_attn_procs = {}
+    lora_attn_processor_class = (
+        LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+    )
     for name, module in text_encoder.named_modules():
         if name.endswith(TEXT_ENCODER_ATTN_MODULE):
-            text_lora_attn_procs[name] = LoRAAttnProcessor(
+            text_lora_attn_procs[name] = lora_attn_processor_class(
                 hidden_size=module.out_proj.out_features, cross_attention_dim=None
             )
     return text_lora_attn_procs
@@ -368,7 +378,10 @@ def test_lora_unet_attn_processors(self):
         # check if lora attention processors are used
         for _, module in sd_pipe.unet.named_modules():
             if isinstance(module, Attention):
-                self.assertIsInstance(module.processor, LoRAAttnProcessor)
+                attn_proc_class = (
+                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+                )
+                self.assertIsInstance(module.processor, attn_proc_class)
 
     @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
     def test_lora_unet_attn_processors_with_xformers(self):
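Outside the test suite, the same check boils down to a sketch like the following: after LoRA layers are attached, every `Attention` module should carry whichever processor class matches the installed PyTorch. The checkpoint id and LoRA path are placeholders.

import torch.nn.functional as F
from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import Attention, LoRAAttnProcessor, LoRAAttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.unet.load_attn_procs("path/to/lora")  # placeholder

expected_class = LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
for _, module in pipe.unet.named_modules():
    if isinstance(module, Attention):
        assert isinstance(module.processor, expected_class)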

tests/models/test_models_unet_3d_condition.py

Lines changed: 2 additions & 2 deletions
@@ -261,7 +261,7 @@ def test_lora_save_load(self):
         with torch.no_grad():
             new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
 
-        assert (sample - new_sample).abs().max() < 1e-4
+        assert (sample - new_sample).abs().max() < 5e-4
 
         # LoRA and no LoRA should NOT be the same
         assert (sample - old_sample).abs().max() > 1e-4
@@ -295,7 +295,7 @@ def test_lora_save_load_safetensors(self):
         with torch.no_grad():
             new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
 
-        assert (sample - new_sample).abs().max() < 1e-4
+        assert (sample - new_sample).abs().max() < 3e-4
 
         # LoRA and no LoRA should NOT be the same
         assert (sample - old_sample).abs().max() > 1e-4
