Attn added kv processor torch 2.0 block #3023
Changes from all commits
@@ -255,11 +255,15 @@ def batch_to_head_dim(self, tensor):
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
 
-    def head_to_batch_dim(self, tensor):
+    def head_to_batch_dim(self, tensor, out_dim=3):
         head_size = self.heads
         batch_size, seq_len, dim = tensor.shape
         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+
+        if out_dim == 3:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+
         return tensor
 
     def get_attention_scores(self, query, key, attention_mask=None):
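The new `out_dim` argument only changes the final reshape: `out_dim=3` keeps the previous behaviour of folding the heads into the batch dimension, while `out_dim=4` leaves the heads as an explicit second dimension, which is the layout `F.scaled_dot_product_attention` works with. A minimal standalone sketch of the two layouts (the concrete sizes below are illustrative, not taken from the diff):

```python
import torch

batch_size, seq_len, dim, heads = 2, 77, 320, 8
tensor = torch.randn(batch_size, seq_len, dim)

# Split channels into (heads, head_dim) and move heads in front of the sequence dim.
t = tensor.reshape(batch_size, seq_len, heads, dim // heads).permute(0, 2, 1, 3)

out_dim_4 = t  # (batch, heads, seq_len, head_dim), used by the torch 2.0 processor
out_dim_3 = t.reshape(batch_size * heads, seq_len, dim // heads)  # previous bmm-friendly layout

print(out_dim_4.shape)  # torch.Size([2, 8, 77, 40])
print(out_dim_3.shape)  # torch.Size([16, 77, 40])
```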
@@ -293,7 +297,7 @@ def get_attention_scores(self, query, key, attention_mask=None):
 
         return attention_probs
 
-    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None):
+    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
         if batch_size is None:
             deprecate(
                 "batch_size=None",
@@ -320,8 +324,13 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
             else:
                 attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
 
-        if attention_mask.shape[0] < batch_size * head_size:
-            attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        if out_dim == 3:
+            if attention_mask.shape[0] < batch_size * head_size:
+                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        elif out_dim == 4:
+            attention_mask = attention_mask.unsqueeze(1)
+            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
 
         return attention_mask
 
     def norm_encoder_hidden_states(self, encoder_hidden_states):
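With `out_dim=4` the mask gets an explicit head dimension at position 1 instead of having the heads repeated into the batch dimension, matching the `(batch, heads, query_len, key_len)` layout that the SDPA path passes as `attn_mask`. A rough sketch of the two resulting shapes (sizes are illustrative):

```python
import torch

batch_size, head_size, target_length = 2, 8, 77
# An already padded additive mask, one row per batch item.
attention_mask = torch.zeros(batch_size, 1, target_length)

# out_dim == 3: heads folded into the batch dimension (the classic bmm layout).
mask_3d = attention_mask.repeat_interleave(head_size, dim=0)
print(mask_3d.shape)  # torch.Size([16, 1, 77])

# out_dim == 4: explicit head dimension for F.scaled_dot_product_attention.
mask_4d = attention_mask.unsqueeze(1).repeat_interleave(head_size, dim=1)
print(mask_4d.shape)  # torch.Size([2, 8, 1, 77])
```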
@@ -499,6 +508,64 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         return hidden_states
 
 
+class AttnAddedKVProcessor2_0:
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        residual = hidden_states
+        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query, out_dim=4)
+
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
+        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
+
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key, out_dim=4)
+            value = attn.head_to_batch_dim(value, out_dim=4)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
 class XFormersAttnProcessor:
     def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op

Review comment on the `attn.scale` TODO line: Out of curiosity. Which scale is this one?
@@ -764,6 +831,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
     SlicedAttnProcessor,
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
+    AttnAddedKVProcessor2_0,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
 ]
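For context, the new processor is selected like any other attention processor. Below is a hedged sketch of how it might be wired up: the `hasattr` guard mirrors the class's own check, while the import path and the commented `set_attn_processor` call are assumptions about the surrounding diffusers API, not part of this diff.

```python
import torch.nn.functional as F

from diffusers.models.attention_processor import AttnAddedKVProcessor, AttnAddedKVProcessor2_0

# Prefer the torch 2.0 processor only when scaled_dot_product_attention exists.
processor_cls = (
    AttnAddedKVProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor
)

# Assuming `model` is a diffusers UNet-style model exposing `set_attn_processor`:
# model.set_attn_processor(processor_cls())
```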
The second changed file is the UnCLIP pipeline test:

@@ -421,7 +421,12 @@ class DummyScheduler:
     def test_attention_slicing_forward_pass(self):
         test_max_difference = torch_device == "cpu"
 
-        self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference)
+        # Check is relaxed because there is not a torch 2.0 sliced attention added kv processor
+        expected_max_diff = 1e-2
+
+        self._test_attention_slicing_forward_pass(
+            test_max_difference=test_max_difference, expected_max_diff=expected_max_diff
+        )
 
     # Overriding PipelineTesterMixin::test_inference_batch_single_identical
     # because UnCLIP undeterminism requires a looser check.

Comment on lines +424 to +429: Should this be conditioned on the torch version being used?

Reply: Sorry, should have noted: the intention is to follow up and add the 2.0 sliced attention added kv processor.
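One way to act on that review question would be to pick the tolerance based on whether torch 2.0's SDPA is available. A hypothetical sketch of the test method under that assumption, relying on the same `torch_device` and tester mixin as the diff; the `1e-3` fallback is an assumed stricter default, and the PR itself keeps the single relaxed tolerance:

```python
def test_attention_slicing_forward_pass(self):
    import torch.nn.functional as F

    test_max_difference = torch_device == "cpu"

    # Hypothetical: only relax the check when the torch 2.0 added-KV processor can run,
    # since there is no sliced counterpart for it yet.
    expected_max_diff = 1e-2 if hasattr(F, "scaled_dot_product_attention") else 1e-3

    self._test_attention_slicing_forward_pass(
        test_max_difference=test_max_difference, expected_max_diff=expected_max_diff
    )
```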
Review comment: @patrickvonplaten when using the torch built-in attention and putting heads as the second dim, we need to make the attention mask also put the heads in the second dim. I'm not sure what the equivalent check for `attention_mask.shape[0] < batch_size * head_size` is. If we assume the input attention mask always has the same batch size as the inputs, we don't have to do the check, and I think this works. My understanding is that's what the original code was doing anyway, since it just repeats by the head size regardless.
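To make that point concrete: with the heads in the second dimension, a mask built per input batch item and repeated along `dim=1` already lines up with the `(batch, heads, q_len, kv_len)` shape `F.scaled_dot_product_attention` broadcasts against, so no batch-vs-heads size check is needed. A small self-contained sketch with illustrative shapes:

```python
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 2, 8, 64, 77, 40

query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# Additive mask prepared per batch item, then given a head dimension and repeated per head,
# mirroring prepare_attention_mask(..., out_dim=4).
attention_mask = torch.zeros(batch, 1, kv_len)
attention_mask = attention_mask.unsqueeze(1).repeat_interleave(heads, dim=1)  # (2, 8, 1, 77)

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
print(out.shape)  # torch.Size([2, 8, 64, 40])
```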