[Core] add QKV fusion to AuraFlow and PixArt Sigma #8952
Merged
Changes shown from 6 of the 17 commits (all by sayakpaul):

40ad9c8  add fusion support to pixart
48dfbb6  add to auraflow.
2ff7a4a  add tests
26c687c  Merge branch 'main' into qkv-rest
e933bde  Merge branch 'main' into qkv-rest
c3ce92c  Merge branch 'main' into qkv-rest
7734a3d  Merge branch 'main' into qkv-rest
62f2af1  apply review feedback.
eca8d2a  Merge branch 'main' into qkv-rest
33310ef  add back args and kwargs
0e5036d  style
c1e6a46  Merge branch 'main' into qkv-rest
544391d  Merge branch 'main' into qkv-rest
9171f8d  Merge branch 'main' into qkv-rest
0e61ed2  resolve brutal conflicts.
632e583  Merge branch 'main' into qkv-rest
8ed3307  Merge branch 'main' into qkv-rest
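For orientation, QKV fusion concatenates an attention block's separate query, key, and value projection matrices into a single larger matmul, which the fused attention processors added in this PR then split back into Q, K, and V. Below is a minimal usage sketch, assuming the `fuse_qkv_projections()` / `unfuse_qkv_projections()` methods this PR wires up for the PixArt Sigma transformer; the checkpoint name and step count are illustrative, not taken from the PR.

```python
# Hedged sketch: exercising QKV fusion on PixArt Sigma after this PR.
import torch
from diffusers import PixArtSigmaPipeline

pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",  # illustrative checkpoint
    torch_dtype=torch.float16,
).to("cuda")

# Merge to_q/to_k/to_v (and the added projections, where present) into single linears.
pipe.transformer.fuse_qkv_projections()

image = pipe("an astronaut riding a horse on the moon", num_inference_steps=20).images[0]

# Restore the original, unfused projection layers.
pipe.transformer.unfuse_qkv_projections()
```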
@@ -219,6 +219,7 @@ def __init__(
             self.to_v = None
 
         if self.added_kv_proj_dim is not None:
+            self.added_proj_bias = added_proj_bias
             self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
             self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
             if self.context_pre_only is not None:
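The stored flag matters because a projection created with `bias=False` has no bias tensor at all, so there is nothing to concatenate when fusing. A toy snippet (not from the PR) makes this explicit:

```python
import torch.nn as nn

proj = nn.Linear(64, 64, bias=False)
print(proj.bias)  # None: nothing to concatenate or copy into a fused layer
```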
@@ -685,12 +686,15 @@ def fuse_projections(self, fuse=True):
             in_features = concatenated_weights.shape[1]
             out_features = concatenated_weights.shape[0]
 
-            self.to_added_qkv = nn.Linear(in_features, out_features, bias=True, device=device, dtype=dtype)
-            self.to_added_qkv.weight.copy_(concatenated_weights)
-            concatenated_bias = torch.cat(
-                [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+            self.to_added_qkv = nn.Linear(
+                in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
             )
-            self.to_added_qkv.bias.copy_(concatenated_bias)
+            self.to_added_qkv.weight.copy_(concatenated_weights)
+            if self.added_proj_bias:
+                concatenated_bias = torch.cat(
+                    [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+                )
+                self.to_added_qkv.bias.copy_(concatenated_bias)
 
         self.fused_projections = fuse
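To make the guarded bias handling above concrete, here is a standalone toy sketch (not the library code; `fuse_linear_projections` is a made-up helper name): three projection layers are folded into one `nn.Linear`, and the bias concatenation only runs when the source projections were built with a bias, which is exactly what the stored `added_proj_bias` flag records.

```python
import torch
import torch.nn as nn


def fuse_linear_projections(q: nn.Linear, k: nn.Linear, v: nn.Linear) -> nn.Linear:
    """Toy version of the guarded fusion above: concatenate the weights, and only
    concatenate/copy the biases when the source projections actually carry one."""
    has_bias = q.bias is not None
    concatenated_weights = torch.cat([q.weight.data, k.weight.data, v.weight.data])
    in_features = concatenated_weights.shape[1]
    out_features = concatenated_weights.shape[0]

    fused = nn.Linear(
        in_features, out_features, bias=has_bias, device=q.weight.device, dtype=q.weight.dtype
    )
    with torch.no_grad():
        fused.weight.copy_(concatenated_weights)
        if has_bias:
            fused.bias.copy_(torch.cat([q.bias.data, k.bias.data, v.bias.data]))
    return fused


# The fused layer reproduces the three separate projections, bias or not.
q, k, v = (nn.Linear(8, 16, bias=False) for _ in range(3))
x = torch.randn(2, 8)
fused = fuse_linear_projections(q, k, v)
assert torch.allclose(fused(x), torch.cat([q(x), k(x), v(x)], dim=-1), atol=1e-6)
```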
@@ -1261,6 +1265,103 @@ def __call__(
         return hidden_states
 
 
+class FusedAuraFlowAttnProcessor2_0:
+    """Attention processor used typically in processing Aura Flow with fused projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+            raise ImportError(
+                "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        batch_size = hidden_states.shape[0]
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            (
+                encoder_hidden_states_query_proj,
+                encoder_hidden_states_key_proj,
+                encoder_hidden_states_value_proj,
+            ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+        # Reshape.
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim)
+        key = key.view(batch_size, -1, attn.heads, head_dim)
+        value = value.view(batch_size, -1, attn.heads, head_dim)
+
+        # Apply QK norm.
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Concatenate the projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            )
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            )
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        # Attention.
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # Split the attention outputs.
+        if encoder_hidden_states is not None:
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+            )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        if encoder_hidden_states is not None:
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class XFormersAttnAddedKVProcessor:
     r"""
     Processor for implementing memory efficient attention using xFormers.

Review comment on lines +1295 to +1296 (the `*args, **kwargs` parameters of `FusedAuraFlowAttnProcessor2_0.__call__`): "we do not need them, no?", with a suggested change to drop them. They were subsequently added back (commit 33310ef, "add back args and kwargs").
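The "add tests" commit presumably covers this kind of before/after equivalence; the helper below is an illustrative sketch of such a check (the function name and tolerance are assumptions, not copied from the test suite):

```python
import torch


@torch.no_grad()
def check_qkv_fusion_equivalence(transformer, inputs: dict, atol: float = 1e-3) -> bool:
    """Illustrative check: fusing and then unfusing the QKV projections should
    leave the transformer's output numerically unchanged."""
    transformer.eval()
    reference = transformer(**inputs)[0]

    transformer.fuse_qkv_projections()    # swap in the fused attention processors
    fused = transformer(**inputs)[0]

    transformer.unfuse_qkv_projections()  # restore the separate q/k/v projections
    restored = transformer(**inputs)[0]

    return torch.allclose(reference, fused, atol=atol) and torch.allclose(reference, restored, atol=atol)
```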