
Commit 222ff51

williamberman authored and w4ffl35 committed
Attn added kv processor torch 2.0 block (huggingface#3023)
add AttnAddedKVProcessor2_0 block
1 parent b626495 commit 222ff51

File tree

4 files changed: +109 additions, -12 deletions

src/diffusers/models/attention_processor.py

Lines changed: 73 additions & 5 deletions
@@ -255,11 +255,15 @@ def batch_to_head_dim(self, tensor):
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor

-    def head_to_batch_dim(self, tensor):
+    def head_to_batch_dim(self, tensor, out_dim=3):
         head_size = self.heads
         batch_size, seq_len, dim = tensor.shape
         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+
+        if out_dim == 3:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+
         return tensor

     def get_attention_scores(self, query, key, attention_mask=None):
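For reference, a minimal standalone sketch (not part of the commit) of what the new out_dim argument controls; the tensor sizes below are made-up values:

import torch

# head_to_batch_dim splits `dim` into (heads, head_dim) and moves the heads next
# to the batch axis. With out_dim=3 the heads are folded into the batch dimension
# (the layout used by the existing bmm/baddbmm path); with out_dim=4 they stay as
# a separate axis (the layout expected by F.scaled_dot_product_attention).
batch_size, seq_len, dim, heads = 2, 16, 64, 8
tensor = torch.randn(batch_size, seq_len, dim)

tensor = tensor.reshape(batch_size, seq_len, heads, dim // heads).permute(0, 2, 1, 3)
out_4d = tensor                                                      # (2, 8, 16, 8)
out_3d = tensor.reshape(batch_size * heads, seq_len, dim // heads)   # (16, 16, 8)

print(out_4d.shape, out_3d.shape)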
@@ -293,7 +297,7 @@ def get_attention_scores(self, query, key, attention_mask=None):

         return attention_probs

-    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None):
+    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
         if batch_size is None:
             deprecate(
                 "batch_size=None",
@@ -320,8 +324,13 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
         else:
             attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

-        if attention_mask.shape[0] < batch_size * head_size:
-            attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        if out_dim == 3:
+            if attention_mask.shape[0] < batch_size * head_size:
+                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        elif out_dim == 4:
+            attention_mask = attention_mask.unsqueeze(1)
+            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
         return attention_mask

     def norm_encoder_hidden_states(self, encoder_hidden_states):
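Similarly, a small sketch (not from the commit; the mask shape and sizes are assumptions for a typical padded mask) of what the new out_dim=4 branch produces: the mask gains a head axis so it broadcasts against the (batch, heads, query_len, key_len) scores inside F.scaled_dot_product_attention.

import torch

# Padded mask as prepare_attention_mask would hold it before the branch.
batch_size, head_size, key_len = 2, 8, 77
attention_mask = torch.zeros(batch_size, 1, key_len)

# out_dim == 4 branch: add a head axis and repeat it once per attention head.
attention_mask = attention_mask.unsqueeze(1)                          # (2, 1, 1, 77)
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)   # (2, 8, 1, 77)

print(attention_mask.shape)  # broadcasts over the query-length axis inside SDPA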
@@ -499,6 +508,64 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         return hidden_states


+class AttnAddedKVProcessor2_0:
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        residual = hidden_states
+        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query, out_dim=4)
+
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
+        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
+
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key, out_dim=4)
+            value = attn.head_to_batch_dim(value, out_dim=4)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
 class XFormersAttnProcessor:
     def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op
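As a usage illustration (not part of the commit), here is a minimal sketch that attaches the new processor to a standalone Attention module configured with added KV projections, roughly as the simple cross-attention blocks in the next file do. All sizes are made up, PyTorch >= 2.0 is assumed, and the constructor arguments mirror the ones those blocks pass (added_kv_proj_dim, norm_num_groups, bias, upcast_softmax):

import torch
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor2_0

# Hypothetical configuration: query_dim equals heads * dim_head, and
# added_kv_proj_dim plays the role of the text-encoder hidden size.
attn = Attention(
    query_dim=64,
    heads=4,
    dim_head=16,
    added_kv_proj_dim=96,
    norm_num_groups=8,
    bias=True,
    upcast_softmax=True,
    processor=AttnAddedKVProcessor2_0(),  # raises ImportError on PyTorch < 2.0
)

hidden_states = torch.randn(2, 64, 8, 8)        # (batch, channels, height, width)
encoder_hidden_states = torch.randn(2, 77, 96)  # (batch, tokens, added_kv_proj_dim)

out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # same (batch, channels, height, width) layout as the input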
@@ -764,6 +831,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
     SlicedAttnProcessor,
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
+    AttnAddedKVProcessor2_0,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
 ]

src/diffusers/models/unet_2d_blocks.py

Lines changed: 19 additions & 4 deletions
@@ -15,10 +15,11 @@

 import numpy as np
 import torch
+import torch.nn.functional as F
 from torch import nn

 from .attention import AdaGroupNorm, AttentionBlock
-from .attention_processor import Attention, AttnAddedKVProcessor
+from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
 from .dual_transformer_2d import DualTransformer2DModel
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
 from .transformer_2d import Transformer2DModel
@@ -612,6 +613,10 @@ def __init__(
         attentions = []

         for _ in range(num_layers):
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=in_channels,
@@ -624,7 +629,7 @@ def __init__(
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
             resnets.append(
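The processor choice above is a plain feature check on the installed PyTorch; the same idiom in isolation, as one might reuse it when wiring up an added-KV attention module by hand (only names introduced by this commit are assumed):

import torch.nn.functional as F

from diffusers.models.attention_processor import AttnAddedKVProcessor, AttnAddedKVProcessor2_0

# Prefer the fused PyTorch 2.0 kernel when F.scaled_dot_product_attention is
# available, otherwise fall back to the original added-KV processor; the block
# constructors in this file do exactly this.
processor = (
    AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
)
print(type(processor).__name__)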
@@ -1396,6 +1401,11 @@ def __init__(
                     skip_time_act=skip_time_act,
                 )
             )
+
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=out_channels,
@@ -1408,7 +1418,7 @@ def __init__(
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
         self.attentions = nn.ModuleList(attentions)
@@ -2399,6 +2409,11 @@ def __init__(
                     skip_time_act=skip_time_act,
                 )
             )
+
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=out_channels,
@@ -2411,7 +2426,7 @@ def __init__(
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
         self.attentions = nn.ModuleList(attentions)

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 11 additions & 2 deletions
@@ -8,7 +8,12 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
 from ...models.attention import Attention
-from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor, AttnProcessor
+from ...models.attention_processor import (
+    AttentionProcessor,
+    AttnAddedKVProcessor,
+    AttnAddedKVProcessor2_0,
+    AttnProcessor,
+)
 from ...models.dual_transformer_2d import DualTransformer2DModel
 from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
@@ -1545,6 +1550,10 @@ def __init__(
         attentions = []

         for _ in range(num_layers):
+            processor = (
+                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+            )
+
             attentions.append(
                 Attention(
                     query_dim=in_channels,
@@ -1557,7 +1566,7 @@ def __init__(
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
                     cross_attention_norm=cross_attention_norm,
-                    processor=AttnAddedKVProcessor(),
+                    processor=processor,
                 )
             )
             resnets.append(

tests/pipelines/unclip/test_unclip_image_variation.py

Lines changed: 6 additions & 1 deletion
@@ -421,7 +421,12 @@ class DummyScheduler:
     def test_attention_slicing_forward_pass(self):
         test_max_difference = torch_device == "cpu"

-        self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference)
+        # Check is relaxed because there is not a torch 2.0 sliced attention added kv processor
+        expected_max_diff = 1e-2
+
+        self._test_attention_slicing_forward_pass(
+            test_max_difference=test_max_difference, expected_max_diff=expected_max_diff
+        )

     # Overriding PipelineTesterMixin::test_inference_batch_single_identical
     # because UnCLIP undeterminism requires a looser check.
