Attention processor cross attention norm group norm #3021

Conversation
@@ -68,7 +69,6 @@ def __init__(
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
-       self.cross_attention_norm = cross_attention_norm
If a layer's existence is configured with a flag, we usually define the layer as either `None` or the layer itself, and then check that it is not `None` in the forward method. I consolidated this check in the `prepare_encoder_hidden_states` method and removed all references to the instance variable.
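For illustration, a minimal sketch of that convention (the class and argument names here are hypothetical and simplified, not the exact diffusers code): the norm layer is either built or left as `None` in `__init__`, and the forward path only checks the attribute rather than a separate boolean flag.

import torch
import torch.nn as nn


class CrossAttentionSketch(nn.Module):
    # Hypothetical, simplified module illustrating the "layer is None or defined" convention.
    def __init__(self, query_dim, cross_attention_dim=None, cross_attention_norm=None):
        super().__init__()
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim

        # Instead of storing a boolean flag, store either the layer or None.
        if cross_attention_norm == "layer_norm":
            self.norm_cross = nn.LayerNorm(cross_attention_dim)
        else:
            self.norm_cross = None

    def forward(self, hidden_states, encoder_hidden_states=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif self.norm_cross is not None:
            # Check the layer itself, not a separate flag.
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        return encoder_hidden_states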
ok for me
The documentation is not available anymore as the PR was closed or merged.
@@ -56,7 +56,8 @@ def __init__(
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
-       cross_attention_norm: bool = False,
+       cross_attention_norm: Union[bool, str] = False,
+       cross_attention_norm_num_groups: int = 32,
iirc, group norms in attention blocks for ldm codebases/variants almost always use 32 groups. We probably won't have to actually set this value often
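As a quick standalone sanity check (not part of the PR), `nn.GroupNorm` only requires that the number of groups divides the number of channels, which holds for the hidden sizes these blocks typically use:

import torch
import torch.nn as nn

# 32 groups divide common encoder hidden sizes such as 768 or 1024.
norm = nn.GroupNorm(num_groups=32, num_channels=768)

# GroupNorm normalizes over the channel dimension, so the input is (N, C, *).
x = torch.randn(2, 768, 77)  # (batch_size, hidden_size, sequence_length)
print(norm(x).shape)  # torch.Size([2, 768, 77])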
Force-pushed from 0670b76 to c3307c3
def prepare_encoder_hidden_states(self, hidden_states, encoder_hidden_states=None):
    if encoder_hidden_states is None:
        return hidden_states

    if self.norm_cross is None:
        return encoder_hidden_states

    if isinstance(self.norm_cross, nn.LayerNorm):
        encoder_hidden_states = self.norm_cross(encoder_hidden_states)
    elif isinstance(self.norm_cross, nn.GroupNorm):
        # Group norm norms along the channels dimension and expects
        # input to be in the shape of (N, C, *). In this case, we want
        # to norm along the hidden dimension, so we need to move
        # (batch_size, sequence_length, hidden_size) ->
        # (batch_size, hidden_size, sequence_length)
        encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
    else:
        assert False

    return encoder_hidden_states
Standardized encoder hidden states preprocessing
Was this method renamed to `norm_encoder_hidden_states()`?
Nevermind.
yessir!
Force-pushed from c3307c3 to 1a90bf6
Force-pushed from 1a90bf6 to 1dd2774
@@ -56,7 +56,8 @@ def __init__(
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
-       cross_attention_norm: bool = False,
+       cross_attention_norm: Optional[str] = None,
This can be safely upgraded to an `Optional[str]` as we only ever use it internally through the "K attention" blocks, which are updated in this PR.
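For reference, a minimal sketch of how an `Optional[str]` config like this could be turned into a concrete norm layer in `__init__` (the helper name, accepted strings, and `eps`/`affine` choices are assumptions for illustration, not necessarily the exact diffusers code):

import torch.nn as nn


def build_norm_cross(cross_attention_dim, cross_attention_norm=None, cross_attention_norm_num_groups=32):
    # Hypothetical helper: map the string config to a norm module (or None).
    if cross_attention_norm is None:
        return None
    if cross_attention_norm == "layer_norm":
        return nn.LayerNorm(cross_attention_dim)
    if cross_attention_norm == "group_norm":
        # Normalizes over the hidden dimension; the caller is expected to pass
        # inputs as (batch_size, hidden_size, sequence_length).
        return nn.GroupNorm(
            num_groups=cross_attention_norm_num_groups,
            num_channels=cross_attention_dim,
            eps=1e-5,
            affine=True,
        )
    raise ValueError(f"Unknown cross_attention_norm: {cross_attention_norm}")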
@@ -291,6 +311,29 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
        attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        return attention_mask

+   def prepare_encoder_hidden_states(self, hidden_states, encoder_hidden_states=None):
not a huge fan of the naming here. Can we call it maybe instead:
-   def prepare_encoder_hidden_states(self, hidden_states, encoder_hidden_states=None):
+   def norm_encoder_hidden_states(self, hidden_states, encoder_hidden_states=None):
Renamed here, and moved the check on whether `encoder_hidden_states` is `None` back to the attention processors to better reflect the new name: 9160e51
Ok for me, but let's maybe give it a different name.
Also, let's maybe make sure not to factor out too much logic, as it makes the function much harder to read.
Force-pushed from 2084154 to 9160e51
nice :)
Force-pushed from 9160e51 to 3cf771f
Force-pushed from 7ac3eff to e35de4b
add group norm type to attention processor cross attention norm

This lets the cross attention norm use either a group norm block or a layer norm block. The group norm operates along the channels dimension and requires input shape (batch size, channels, *), whereas the layer norm with a single `normalized_shape` dimension only operates over the least significant dimension, i.e. (*, channels). The channels we want to normalize are the hidden dimension of the encoder hidden states. By convention, the encoder hidden states are always passed as (batch size, sequence length, hidden states). This means the layer norm can operate on the tensor without modification, but the group norm requires flipping the last two dimensions to operate on (batch size, hidden states, sequence length). All existing attention processors will have the same logic and we can consolidate it in a helper function `prepare_encoder_hidden_states`.

prepare_encoder_hidden_states -> norm_encoder_hidden_states

re: @patrickvonplaten move norm_cross defined check to outside norm_encoder_hidden_states

add missing attn.norm_cross check
This lets the cross attention norm use either a group norm block or a layer norm block.

The group norm operates along the channels dimension and requires input shape (batch size, channels, *), whereas the layer norm with a single `normalized_shape` dimension only operates over the least significant dimension, i.e. (*, channels).

The channels we want to normalize are the hidden dimension of the encoder hidden states. By convention, the encoder hidden states are always passed as (batch size, sequence length, hidden states).

This means the layer norm can operate on the tensor without modification, but the group norm requires flipping the last two dimensions to operate on (batch size, hidden states, sequence length).

All existing attention processors will have the same logic and we can consolidate it in a helper function `prepare_encoder_hidden_states`.

This is rebased on top of #3014 because it requires the bug fix of making sure the added kv processors always take encoder states as (batch size, sequence length, hidden states).
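To make the shape argument concrete, a small standalone check (not part of the PR itself) showing that the group norm needs the transpose while the layer norm does not:

import torch
import torch.nn as nn

batch_size, sequence_length, hidden_size = 2, 77, 768
encoder_hidden_states = torch.randn(batch_size, sequence_length, hidden_size)

# LayerNorm normalizes the last dimension, so (batch, seq, hidden) works directly.
layer_norm = nn.LayerNorm(hidden_size)
out_ln = layer_norm(encoder_hidden_states)
print(out_ln.shape)  # torch.Size([2, 77, 768])

# GroupNorm normalizes the channel dimension of (N, C, *), so the hidden
# dimension has to be moved next to the batch dimension and back afterwards.
group_norm = nn.GroupNorm(num_groups=32, num_channels=hidden_size)
out_gn = group_norm(encoder_hidden_states.transpose(1, 2)).transpose(1, 2)
print(out_gn.shape)  # torch.Size([2, 77, 768])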