Attention processor cross attention norm group norm #3021
@@ -56,7 +56,8 @@ def __init__(
         bias=False,
         upcast_attention: bool = False,
         upcast_softmax: bool = False,
-        cross_attention_norm: bool = False,
+        cross_attention_norm: Optional[str] = None,
+        cross_attention_norm_num_groups: int = 32,

Review comment: IIRC, group norms in attention blocks for LDM codebases/variants almost always use 32 groups, so we probably won't have to set this value often in practice. (A short usage sketch follows this hunk.)
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
         out_bias: bool = True,
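A minimal usage sketch of the new arguments, assuming the Attention constructor lands as shown in this hunk and is importable from diffusers.models.attention_processor (the import path and the concrete sizes here are illustrative, not taken from the PR):

    from diffusers.models.attention_processor import Attention  # import path assumed

    attn = Attention(
        query_dim=320,                       # illustrative UNet channel count
        cross_attention_dim=1024,            # illustrative text-encoder hidden size
        cross_attention_norm="group_norm",   # string-valued flag replacing the old bool
        cross_attention_norm_num_groups=32,  # new argument; 32 is the default, spelled out for clarity
    )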
@@ -69,7 +70,6 @@ def __init__(
         cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
-        self.cross_attention_norm = cross_attention_norm

Review comment: If a layer's existence is configured with a flag, we usually define the layer as either None or a module, and then check that the layer is not None in the forward method. I consolidated this check in the … (a generic sketch of this pattern follows this hunk).
Reply: ok for me.

         self.scale = dim_head**-0.5 if scale_qk else 1.0

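A generic sketch of the pattern described in the comment above (the class and argument names are made up for illustration; this is not code from the PR):

    import torch.nn as nn

    class ExampleBlock(nn.Module):
        def __init__(self, dim: int, use_norm: bool = False):
            super().__init__()
            # the optional layer is either a module or None...
            self.norm = nn.LayerNorm(dim) if use_norm else None

        def forward(self, x):
            # ...and the forward pass checks for its existence instead of re-reading the flag
            if self.norm is not None:
                x = self.norm(x)
            return x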
@@ -92,8 +92,28 @@ def __init__(
         else:
             self.group_norm = None

-        if cross_attention_norm:
-            self.norm_cross = nn.LayerNorm(cross_attention_dim)
+        if cross_attention_norm is None:
+            self.norm_cross = None
+        elif cross_attention_norm == "layer_norm":
+            self.norm_cross = nn.LayerNorm(cross_attention_dim)
+        elif cross_attention_norm == "group_norm":
+            if self.added_kv_proj_dim is not None:
+                # The given `encoder_hidden_states` are initially of shape
+                # (batch_size, seq_len, added_kv_proj_dim) before being projected
+                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+                # before the projection, so we need to use `added_kv_proj_dim` as
+                # the number of channels for the group norm.
+                norm_cross_num_channels = added_kv_proj_dim
+            else:
+                norm_cross_num_channels = cross_attention_dim
+
+            self.norm_cross = nn.GroupNorm(
+                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+            )
+        else:
+            raise ValueError(
+                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+            )

         self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)

@@ -304,6 +324,25 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
         attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
         return attention_mask

+    def norm_encoder_hidden_states(self, encoder_hidden_states):
+        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+
+        if isinstance(self.norm_cross, nn.LayerNorm):
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+        elif isinstance(self.norm_cross, nn.GroupNorm):
+            # Group norm norms along the channels dimension and expects
+            # input to be in the shape of (N, C, *). In this case, we want
+            # to norm along the hidden dimension, so we need to move
+            # (batch_size, sequence_length, hidden_size) ->
+            # (batch_size, hidden_size, sequence_length)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+        else:
+            assert False
+
+        return encoder_hidden_states
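A standalone sketch of why the transpose is needed in the method above: nn.GroupNorm normalizes the channel dimension of (N, C, *) input, while encoder_hidden_states arrives as (batch, seq_len, hidden), so the hidden dimension is moved into the channel slot and back (the sizes below are illustrative):

    import torch
    import torch.nn as nn

    batch, seq_len, hidden = 2, 77, 768  # illustrative shapes
    encoder_hidden_states = torch.randn(batch, seq_len, hidden)

    # 32 groups matches the default added in this PR; hidden must be divisible by it
    norm_cross = nn.GroupNorm(num_groups=32, num_channels=hidden, eps=1e-5, affine=True)

    normed = norm_cross(encoder_hidden_states.transpose(1, 2)).transpose(1, 2)
    assert normed.shape == (batch, seq_len, hidden)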


 class AttnProcessor:
     def __call__(
@@ -321,8 +360,8 @@ def __call__(

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -388,7 +427,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query)

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -416,6 +458,11 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

         query = attn.to_q(hidden_states)
@@ -467,8 +514,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -511,8 +558,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -561,7 +608,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query).contiguous()

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -598,8 +648,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -647,6 +697,11 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

         query = attn.to_q(hidden_states)
Review comment: This can be safely upgraded to an Optional[str], as we only ever use it internally through the "K attention" blocks, which are updated in this PR.
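A small sanity-check sketch of the three accepted values of the string flag, under the same assumption about the import path as in the earlier sketch (sizes are illustrative):

    import torch.nn as nn
    from diffusers.models.attention_processor import Attention  # import path assumed

    for value, expected in [(None, type(None)), ("layer_norm", nn.LayerNorm), ("group_norm", nn.GroupNorm)]:
        attn = Attention(query_dim=320, cross_attention_dim=768, cross_attention_norm=value)
        assert isinstance(attn.norm_cross, expected)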