Attention processor cross attention norm group norm #3021

Conversation

@williamberman williamberman (Contributor) commented Apr 8, 2023

This lets the cross attention norm use either a group norm block or a
layer norm block.

The group norm operates along the channels dimension
and requires input shape (batch size, channels, *), whereas the layer norm with a single
`normalized_shape` dimension only operates over the least significant
dimension, i.e. (*, channels).

The channels we want to normalize along are the hidden dimension of the encoder hidden states.

By convention, the encoder hidden states are always passed as (batch size, sequence
length, hidden states).

This means the layer norm can operate on the tensor without modification, but the group
norm requires flipping the last two dimensions to operate on (batch size, hidden states, sequence length).
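
For concreteness, a minimal standalone sketch of the shape handling described above; the tensor sizes are illustrative, not taken from this PR:

import torch
from torch import nn

# Illustrative sizes only; 768 is divisible by the conventional 32 groups.
batch, seq_len, hidden = 2, 77, 768
encoder_hidden_states = torch.randn(batch, seq_len, hidden)

# LayerNorm with a single normalized_shape normalizes the last dimension,
# so (batch, seq_len, hidden) can be passed as-is.
layer_norm = nn.LayerNorm(hidden)
out_ln = layer_norm(encoder_hidden_states)

# GroupNorm normalizes over the channels dimension and expects (N, C, *),
# so the hidden dimension has to be moved into the channel slot and back.
group_norm = nn.GroupNorm(num_groups=32, num_channels=hidden)
out_gn = group_norm(encoder_hidden_states.transpose(1, 2)).transpose(1, 2)

assert out_ln.shape == out_gn.shape == (batch, seq_len, hidden)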

All existing attention processors will have the same logic, and we can
consolidate it in a helper function `prepare_encoder_hidden_states`.

This is rebased on top of #3014 because it requires the bug fix of making sure the added KV processors always take encoder hidden states as (batch size, sequence
length, hidden states).

@@ -68,7 +69,6 @@ def __init__(
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
- self.cross_attention_norm = cross_attention_norm
@williamberman williamberman (Contributor, Author) Apr 8, 2023

If a layer's existence is configured with a flag, we usually define the layer as none or defined, and then check if the layer is not none in the forward method. I consolidated this check in the prepare_encoder_hidden_states method and removed all references to the instance variable
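
A minimal, hypothetical sketch of that convention (the class and names below are illustrative, not taken from diffusers):

from torch import nn

class ExampleBlock(nn.Module):
    def __init__(self, dim, use_norm=False):
        super().__init__()
        # The layer is either a module or None, depending on the config flag.
        self.norm = nn.LayerNorm(dim) if use_norm else None

    def forward(self, x):
        # Downstream code checks the layer itself instead of re-reading the flag.
        if self.norm is not None:
            x = self.norm(x)
        return x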

Contributor

ok for me

HuggingFaceDocBuilderDev commented Apr 8, 2023

The documentation is not available anymore as the PR was closed or merged.

@@ -56,7 +56,8 @@ def __init__(
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
- cross_attention_norm: bool = False,
+ cross_attention_norm: Union[bool, str] = False,
+ cross_attention_norm_num_groups: int = 32,
Contributor Author

iirc, group norms in attention blocks for ldm codebases/variants almost always use 32 groups. We probably won't have to actually set this value often
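
For illustration, a hedged sketch of how a constructor could map the new arguments onto a norm layer; the accepted string values and the exact branching are assumptions based on this PR's description, not the verbatim diffusers code:

from typing import Optional, Union
from torch import nn

def build_cross_attention_norm(
    cross_attention_dim: int,
    cross_attention_norm: Union[bool, str, None] = False,
    cross_attention_norm_num_groups: int = 32,
) -> Optional[nn.Module]:
    # No norm requested.
    if cross_attention_norm in (False, None):
        return None
    # Boolean True keeps the old behaviour: a layer norm over the hidden dim.
    if cross_attention_norm is True or cross_attention_norm == "layer_norm":
        return nn.LayerNorm(cross_attention_dim)
    # Group norm over the hidden dim, with the conventional 32 groups.
    if cross_attention_norm == "group_norm":
        return nn.GroupNorm(
            num_groups=cross_attention_norm_num_groups,
            num_channels=cross_attention_dim,
        )
    raise ValueError(f"Unknown cross_attention_norm: {cross_attention_norm}")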

@williamberman williamberman force-pushed the attention_processor_cross_attention_norm_group_norm branch 4 times, most recently from 0670b76 to c3307c3 on April 8, 2023 22:48
Comment on lines 314 to 345
def prepare_encoder_hidden_states(self, hidden_states, encoder_hidden_states=None):
if encoder_hidden_states is None:
return hidden_states

if self.norm_cross is None:
return encoder_hidden_states

if isinstance(self.norm_cross, nn.LayerNorm):
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
elif isinstance(self.norm_cross, nn.GroupNorm):
# Group norm norms along the channels dimension and expects
# input to be in the shape of (N, C, *). In this case, we want
# to norm along the hidden dimension, so we need to move
# (batch_size, sequence_length, hidden_size) ->
# (batch_size, hidden_size, sequence_length)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
else:
assert False

return encoder_hidden_states

Contributor Author

Standardized encoder hidden states preprocessing
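
For context, a rough sketch of how a processor would then consume this helper; the wrapper function name below is hypothetical, while `to_q`/`to_k`/`to_v` are the attention projection layers:

import torch

def processor_prologue(attn, hidden_states: torch.Tensor, encoder_hidden_states=None):
    # Shared prologue: fall back to self-attention and/or normalize the
    # encoder hidden states, then build the query/key/value projections.
    encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
    query = attn.to_q(hidden_states)
    key = attn.to_k(encoder_hidden_states)
    value = attn.to_v(encoder_hidden_states)
    return query, key, value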

Member

Was this method renamed to norm_encoder_hidden_states()?

Member

Nevermind.

9160e51

Contributor Author

yessir!

@williamberman williamberman force-pushed the attention_processor_cross_attention_norm_group_norm branch from c3307c3 to 1a90bf6 on April 8, 2023 23:11
@williamberman williamberman force-pushed the attention_processor_cross_attention_norm_group_norm branch from 1a90bf6 to 1dd2774 on April 8, 2023 23:21
@@ -56,7 +56,8 @@ def __init__(
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
- cross_attention_norm: bool = False,
+ cross_attention_norm: Optional[str] = None,
Contributor Author

This can be safely upgraded to an Optional[str] as we only ever use it internally through the "K attention" blocks which are updated in this PR

@@ -291,6 +311,29 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
return attention_mask

def prepare_encoder_hidden_states(self, hidden_states, encoder_hidden_states=None):
Contributor

Not a huge fan of the naming here. Can we maybe call it this instead:

Suggested change
- def prepare_encoder_hidden_states(self, hidden_states, encoder_hidden_states=None):
+ def norm_encoder_hidden_states(self, hidden_states, encoder_hidden_states=None):

Contributor Author

Renamed here, and moved the check for `encoder_hidden_states is None` back to the attention processors to better reflect the new name: 9160e51
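
A sketch of the resulting pattern inside each processor (not the verbatim merged code; the function name below is hypothetical and the exact `norm_cross` check may differ):

def apply_norm_cross(attn, hidden_states, encoder_hidden_states=None):
    if encoder_hidden_states is None:
        # Self-attention fallback: reuse the hidden states themselves.
        encoder_hidden_states = hidden_states
    elif attn.norm_cross is not None:
        # Cross attention with a configured norm: normalize the hidden dim.
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
    return encoder_hidden_states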

@patrickvonplaten patrickvonplaten (Contributor) left a comment

Ok for me, but let's maybe give it a different name.
Also, let's maybe make sure not to factor out too much logic, as it makes the function much harder to read.

@williamberman williamberman force-pushed the attention_processor_cross_attention_norm_group_norm branch 2 times, most recently from 2084154 to 9160e51 on April 11, 2023 00:21
@yiyixuxu yiyixuxu (Collaborator) left a comment

nice :)

@williamberman williamberman force-pushed the attention_processor_cross_attention_norm_group_norm branch from 9160e51 to 3cf771f on April 11, 2023 17:39
This lets the cross attention norm use both a group norm block and a
layer norm block.

The group norm operates along the channels dimension
and requires input shape (batch size, channels, *), whereas the layer norm with a single
`normalized_shape` dimension only operates over the least significant
dimension i.e. (*, channels).

The channels we want to normalize are the hidden dimension of the encoder hidden states.

By convention, the encoder hidden states are always passed as (batch size, sequence
length, hidden states).

This means the layer norm can operate on the tensor without modification, but the group
norm requires flipping the last two dimensions to operate on (batch size, hidden states, sequence length).

All existing attention processors will have the same logic and we can
consolidate it in a helper function `prepare_encoder_hidden_states`

prepare_encoder_hidden_states -> norm_encoder_hidden_states re: @patrickvonplaten

move norm_cross defined check to outside norm_encoder_hidden_states

add missing attn.norm_cross check
@williamberman williamberman force-pushed the attention_processor_cross_attention_norm_group_norm branch from 7ac3eff to e35de4b Compare April 11, 2023 21:43
@williamberman williamberman merged commit 98c5e5d into huggingface:main Apr 11, 2023
w4ffl35 pushed a commit to w4ffl35/diffusers that referenced this pull request Apr 14, 2023
add group norm type to attention processor cross attention norm

dg845 pushed a commit to dg845/diffusers that referenced this pull request May 6, 2023
add group norm type to attention processor cross attention norm

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
add group norm type to attention processor cross attention norm

AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
add group norm type to attention processor cross attention norm
