Attn added kv processor torch 2.0 block #3023

Merged
78 changes: 73 additions & 5 deletions src/diffusers/models/attention_processor.py
@@ -255,11 +255,15 @@ def batch_to_head_dim(self, tensor):
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor

def head_to_batch_dim(self, tensor):
def head_to_batch_dim(self, tensor, out_dim=3):
head_size = self.heads
batch_size, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3)

if out_dim == 3:
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)

return tensor
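
For reference, a small shape check of the two layouts head_to_batch_dim now produces (the batch/head/sequence sizes below are illustrative, not taken from the PR):

```python
# Sketch of the reshapes above: out_dim=4 keeps heads as an explicit dimension,
# out_dim=3 folds them into the batch dimension.
import torch

batch, seq_len, dim, heads = 2, 77, 320, 8
x = torch.randn(batch, seq_len, dim)

y = x.reshape(batch, seq_len, heads, dim // heads).permute(0, 2, 1, 3)
print(y.shape)  # out_dim=4 layout: torch.Size([2, 8, 77, 40])

z = y.reshape(batch * heads, seq_len, dim // heads)
print(z.shape)  # out_dim=3 layout: torch.Size([16, 77, 40])
```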

def get_attention_scores(self, query, key, attention_mask=None):
@@ -293,7 +297,7 @@ def get_attention_scores(self, query, key, attention_mask=None):

return attention_probs

def prepare_attention_mask(self, attention_mask, target_length, batch_size=None):
def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
if batch_size is None:
deprecate(
"batch_size=None",
@@ -320,8 +324,13 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
else:
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

if attention_mask.shape[0] < batch_size * head_size:
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
if out_dim == 3:
if attention_mask.shape[0] < batch_size * head_size:
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
elif out_dim == 4:
attention_mask = attention_mask.unsqueeze(1)
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

Comment on lines +327 to +333 (Contributor Author):

@patrickvonplaten when using the torch built-in attention and putting the heads in the second dim, we need the attention mask to put the heads in the second dim as well. I'm not sure what the equivalent of the attention_mask.shape[0] < batch_size * head_size check is. If we assume the input attention mask always has the same batch size as the inputs, we don't need the check, and I think this works. My understanding is that's what the original code was doing anyway, since it just repeats by the head size regardless.
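
A minimal shape sketch of the two mask layouts discussed in this thread (assuming the incoming mask has already been padded to (batch, 1, key_length); the sizes are illustrative):

```python
# out_dim=3: heads folded into the batch dim, matching (batch * heads, q_len, k_len) scores.
# out_dim=4: heads as an explicit second dim, broadcastable against
# F.scaled_dot_product_attention's (batch, heads, q_len, k_len) layout.
import torch

batch, heads, key_len = 2, 8, 64
attention_mask = torch.zeros(batch, 1, key_len)

mask_3d = attention_mask.repeat_interleave(heads, dim=0)
print(mask_3d.shape)  # torch.Size([16, 1, 64])

mask_4d = attention_mask.unsqueeze(1).repeat_interleave(heads, dim=1)
print(mask_4d.shape)  # torch.Size([2, 8, 1, 64])
```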

return attention_mask

def norm_encoder_hidden_states(self, encoder_hidden_states):
@@ -499,6 +508,64 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
return hidden_states


class AttnAddedKVProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)

def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape

attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query, out_dim=4)

encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)

if not attn.only_cross_attention:
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
key = attn.head_to_batch_dim(key, out_dim=4)
value = attn.head_to_batch_dim(value, out_dim=4)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
else:
key = encoder_hidden_states_key_proj
value = encoder_hidden_states_value_proj

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
Comment (Member):

Out of curiosity. Which scale is this one?
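
For reference, a hedged sketch of what the TODO above points at: PyTorch 2.1 adds an explicit scale argument to scaled_dot_product_attention, so attn.scale could then be forwarded instead of relying on the default 1/sqrt(head_dim) scaling. The shapes and scale value below are illustrative, and this is not part of this PR:

```python
# Hypothetical follow-up (requires torch >= 2.1 for the `scale` kwarg), not part of this PR.
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 2, 8, 64, 141, 40
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)
custom_scale = 1.0  # stand-in for attn.scale

out = F.scaled_dot_product_attention(
    query, key, value, dropout_p=0.0, is_causal=False, scale=custom_scale
)
print(out.shape)  # torch.Size([2, 8, 64, 40])
```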

hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
hidden_states = hidden_states + residual

return hidden_states


class XFormersAttnProcessor:
def __init__(self, attention_op: Optional[Callable] = None):
self.attention_op = attention_op
@@ -764,6 +831,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
SlicedAttnProcessor,
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
LoRAAttnProcessor,
LoRAXFormersAttnProcessor,
]
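
For context, a hedged usage sketch of the same version gate applied outside the block definitions (assumes a model exposing set_attn_processor, e.g. a diffusers UNet; the call itself is illustrative, not prescribed by this PR):

```python
# Pick the torch 2.0 added-KV processor when the fused kernel is available,
# otherwise fall back to the eager implementation.
import torch.nn.functional as F

from diffusers.models.attention_processor import AttnAddedKVProcessor, AttnAddedKVProcessor2_0

processor = (
    AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
)
# unet.set_attn_processor(processor)  # assuming `unet` is a loaded model with added-KV attention blocks
```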
23 changes: 19 additions & 4 deletions src/diffusers/models/unet_2d_blocks.py
@@ -15,10 +15,11 @@

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from .attention import AdaGroupNorm, AttentionBlock
from .attention_processor import Attention, AttnAddedKVProcessor
from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
from .dual_transformer_2d import DualTransformer2DModel
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
from .transformer_2d import Transformer2DModel
@@ -612,6 +613,10 @@ def __init__(
attentions = []

for _ in range(num_layers):
processor = (
AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
)

attentions.append(
Attention(
query_dim=in_channels,
@@ -624,7 +629,7 @@ def __init__(
upcast_softmax=True,
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
processor=AttnAddedKVProcessor(),
processor=processor,
)
)
resnets.append(
@@ -1396,6 +1401,11 @@ def __init__(
skip_time_act=skip_time_act,
)
)

processor = (
AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
)

attentions.append(
Attention(
query_dim=out_channels,
@@ -1408,7 +1418,7 @@ def __init__(
upcast_softmax=True,
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
processor=AttnAddedKVProcessor(),
processor=processor,
)
)
self.attentions = nn.ModuleList(attentions)
@@ -2399,6 +2409,11 @@ def __init__(
skip_time_act=skip_time_act,
)
)

processor = (
AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
)

attentions.append(
Attention(
query_dim=out_channels,
@@ -2411,7 +2426,7 @@ def __init__(
upcast_softmax=True,
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
processor=AttnAddedKVProcessor(),
processor=processor,
)
)
self.attentions = nn.ModuleList(attentions)
13 changes: 11 additions & 2 deletions src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -8,7 +8,12 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
from ...models.attention import Attention
from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor, AttnProcessor
from ...models.attention_processor import (
AttentionProcessor,
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
AttnProcessor,
)
from ...models.dual_transformer_2d import DualTransformer2DModel
from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from ...models.transformer_2d import Transformer2DModel
@@ -1545,6 +1550,10 @@ def __init__(
attentions = []

for _ in range(num_layers):
processor = (
AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
)

attentions.append(
Attention(
query_dim=in_channels,
@@ -1557,7 +1566,7 @@ def __init__(
upcast_softmax=True,
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
processor=AttnAddedKVProcessor(),
processor=processor,
)
)
resnets.append(
7 changes: 6 additions & 1 deletion tests/pipelines/unclip/test_unclip_image_variation.py
@@ -421,7 +421,12 @@ class DummyScheduler:
def test_attention_slicing_forward_pass(self):
test_max_difference = torch_device == "cpu"

self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference)
# Check is relaxed because there is no torch 2.0 sliced attention added-KV processor yet
expected_max_diff = 1e-2

self._test_attention_slicing_forward_pass(
test_max_difference=test_max_difference, expected_max_diff=expected_max_diff
)
Comment on lines +424 to +429 (Member):

Should this be conditioned on the torch version being used?

Reply (Contributor Author, @williamberman, Apr 11, 2023):

Sorry, should have noted: the intention is to follow up and add the torch 2.0 sliced attention added-KV processor.
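
If the tolerance were gated on the torch version instead, one hypothetical shape for it (not what this PR does; the fallback value below is illustrative) would be:

```python
# Hypothetical sketch, not part of this PR: only relax the tolerance when the
# torch 2.0 added-KV processor (which has no sliced counterpart yet) is in use.
import torch.nn.functional as F

expected_max_diff = 1e-2 if hasattr(F, "scaled_dot_product_attention") else 1e-3  # fallback is illustrative
```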


# Overriding PipelineTesterMixin::test_inference_batch_single_identical
# because UnCLIP undeterminism requires a looser check.