Attention processor cross attention norm group norm #3021

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
81 changes: 68 additions & 13 deletions src/diffusers/models/attention_processor.py
@@ -56,7 +56,8 @@ def __init__(
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
cross_attention_norm: bool = False,
cross_attention_norm: Optional[str] = None,
Contributor Author:
This can be safely upgraded to an Optional[str] as we only ever use it internally through the "K attention" blocks which are updated in this PR
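For reference, a minimal sketch of how the new argument is meant to be passed (not part of the diff; it assumes the constructor defaults of the `Attention` class touched here, and the dims are made up):

```python
from diffusers.models.attention_processor import Attention

# None          -> no norm on encoder_hidden_states (the old `False`)
# "layer_norm"  -> nn.LayerNorm(cross_attention_dim) (the old `True`)
# "group_norm"  -> nn.GroupNorm over the pre-projection channels (new in this PR)
attn = Attention(
    query_dim=320,
    cross_attention_dim=768,
    cross_attention_norm="layer_norm",
)
```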

cross_attention_norm_num_groups: int = 32,
Contributor Author:
IIRC, group norms in attention blocks for LDM codebases/variants almost always use 32 groups, so we probably won't have to set this value often.
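As a standalone illustration of that default (not part of the diff; the sizes are made up), a 32-group GroupNorm only requires the channel count to be divisible by the group count:

```python
import torch
import torch.nn as nn

# Typical LDM-style setting: 32 groups over the cross-attention channel dim.
# The only hard requirement is that num_channels is divisible by num_groups.
norm = nn.GroupNorm(num_groups=32, num_channels=768, eps=1e-5, affine=True)

# GroupNorm expects channels-second input, i.e. (N, C, *).
x = torch.randn(2, 768, 77)  # (batch, channels, sequence length)
print(norm(x).shape)  # torch.Size([2, 768, 77])
```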

added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
out_bias: bool = True,
@@ -69,7 +70,6 @@ def __init__(
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.cross_attention_norm = cross_attention_norm
Contributor Author (@williamberman, Apr 8, 2023):
If a layer's existence is configured with a flag, we usually define the layer as either None or an actual module, and then check that the layer is not None in the forward method. I consolidated this check in the norm_encoder_hidden_states method and removed all references to the instance variable.

Contributor:
ok for me
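A minimal sketch of the "define the layer as None or a module, check in forward" pattern described above (illustrative names only, not the actual diffusers classes):

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim: int, use_norm: bool = False):
        super().__init__()
        # Define the optional layer up front instead of carrying a boolean flag.
        self.norm = nn.LayerNorm(dim) if use_norm else None
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The forward pass only checks whether the layer exists.
        if self.norm is not None:
            x = self.norm(x)
        return self.proj(x)
```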


self.scale = dim_head**-0.5 if scale_qk else 1.0

@@ -92,8 +92,28 @@ def __init__(
else:
self.group_norm = None

if cross_attention_norm:
if cross_attention_norm is None:
self.norm_cross = None
elif cross_attention_norm == "layer_norm":
self.norm_cross = nn.LayerNorm(cross_attention_dim)
elif cross_attention_norm == "group_norm":
if self.added_kv_proj_dim is not None:
# The given `encoder_hidden_states` are initially of shape
# (batch_size, seq_len, added_kv_proj_dim) before being projected
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
# before the projection, so we need to use `added_kv_proj_dim` as
# the number of channels for the group norm.
norm_cross_num_channels = added_kv_proj_dim
else:
norm_cross_num_channels = cross_attention_dim

self.norm_cross = nn.GroupNorm(
num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
)
else:
raise ValueError(
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
)

self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)

@@ -304,6 +324,25 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
return attention_mask

def norm_encoder_hidden_states(self, encoder_hidden_states):
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

if isinstance(self.norm_cross, nn.LayerNorm):
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
elif isinstance(self.norm_cross, nn.GroupNorm):
# Group norm norms along the channels dimension and expects
# input to be in the shape of (N, C, *). In this case, we want
# to norm along the hidden dimension, so we need to move
# (batch_size, sequence_length, hidden_size) ->
# (batch_size, hidden_size, sequence_length)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
else:
assert False

return encoder_hidden_states
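For reference, the transpose trick in the GroupNorm branch can be checked in isolation (a standalone sketch with made-up sizes, not part of the diff):

```python
import torch
import torch.nn as nn

batch_size, seq_len, hidden_size = 2, 77, 768
encoder_hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# GroupNorm normalizes over dim 1, so move the hidden dimension there first,
# then move it back so downstream projections see (batch, seq, hidden) again.
norm = nn.GroupNorm(num_groups=32, num_channels=hidden_size)
out = norm(encoder_hidden_states.transpose(1, 2)).transpose(1, 2)
print(out.shape)  # torch.Size([2, 77, 768])
```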


class AttnProcessor:
def __call__(
@@ -321,8 +360,8 @@ def __call__(

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@@ -388,7 +427,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query)

encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -416,6 +458,11 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states)
@@ -467,8 +514,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@@ -511,8 +558,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@@ -561,7 +608,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query).contiguous()

encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -598,8 +648,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@@ -647,6 +697,11 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,

attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states)
18 changes: 14 additions & 4 deletions src/diffusers/models/unet_2d_blocks.py
@@ -44,6 +44,7 @@ def get_down_block(
resnet_time_scale_shift="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock2D":
@@ -126,6 +127,7 @@ def get_down_block(
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
)
elif down_block_type == "SkipDownBlock2D":
return SkipDownBlock2D(
@@ -223,6 +225,7 @@ def get_up_block(
resnet_time_scale_shift="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D":
@@ -293,6 +296,7 @@ def get_up_block(
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
)
elif up_block_type == "AttnUpBlock2D":
return AttnUpBlock2D(
@@ -578,6 +582,7 @@ def __init__(
cross_attention_dim=1280,
skip_time_act=False,
only_cross_attention=False,
cross_attention_norm=None,
):
super().__init__()

@@ -618,6 +623,7 @@ def __init__(
bias=True,
upcast_softmax=True,
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
processor=AttnAddedKVProcessor(),
)
)
@@ -1361,6 +1367,7 @@ def __init__(
add_downsample=True,
skip_time_act=False,
only_cross_attention=False,
cross_attention_norm=None,
):
super().__init__()

@@ -1400,6 +1407,7 @@ def __init__(
bias=True,
upcast_softmax=True,
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
processor=AttnAddedKVProcessor(),
)
)
@@ -1580,7 +1588,7 @@ def __init__(
temb_channels=temb_channels,
attention_bias=True,
add_self_attention=add_self_attention,
cross_attention_norm=True,
cross_attention_norm="layer_norm",
group_size=resnet_group_size,
)
)
@@ -2361,6 +2369,7 @@ def __init__(
add_upsample=True,
skip_time_act=False,
only_cross_attention=False,
cross_attention_norm=None,
):
super().__init__()
resnets = []
@@ -2401,6 +2410,7 @@ def __init__(
bias=True,
upcast_softmax=True,
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
processor=AttnAddedKVProcessor(),
)
)
@@ -2608,7 +2618,7 @@ def __init__(
temb_channels=temb_channels,
attention_bias=True,
add_self_attention=add_self_attention,
cross_attention_norm=True,
cross_attention_norm="layer_norm",
upcast_attention=upcast_attention,
)
)
@@ -2703,7 +2713,7 @@ def __init__(
upcast_attention: bool = False,
temb_channels: int = 768, # for ada_group_norm
add_self_attention: bool = False,
cross_attention_norm: bool = False,
cross_attention_norm: Optional[str] = None,
group_size: int = 32,
):
super().__init__()
@@ -2719,7 +2729,7 @@ def __init__(
dropout=dropout,
bias=attention_bias,
cross_attention_dim=None,
cross_attention_norm=False,
cross_attention_norm=None,
)

# 2. Cross-Attn
4 changes: 4 additions & 0 deletions src/diffusers/models/unet_2d_condition.py
@@ -164,6 +164,7 @@ def __init__(
projection_class_embeddings_input_dim: Optional[int] = None,
class_embeddings_concat: bool = False,
mid_block_only_cross_attention: Optional[bool] = None,
cross_attention_norm: Optional[str] = None,
):
super().__init__()

@@ -323,6 +324,7 @@ def __init__(
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
)
self.down_blocks.append(down_block)

@@ -355,6 +357,7 @@ def __init__(
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
only_cross_attention=mid_block_only_cross_attention,
cross_attention_norm=cross_attention_norm,
)
elif mid_block_type is None:
self.mid_block = None
@@ -406,6 +409,7 @@ def __init__(
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -243,8 +243,8 @@ def __call__(

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@@ -65,8 +65,8 @@ def __call__(

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@@ -250,6 +250,7 @@ def __init__(
projection_class_embeddings_input_dim: Optional[int] = None,
class_embeddings_concat: bool = False,
mid_block_only_cross_attention: Optional[bool] = None,
cross_attention_norm: Optional[str] = None,
):
super().__init__()

@@ -415,6 +416,7 @@ def __init__(
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
)
self.down_blocks.append(down_block)

@@ -447,6 +449,7 @@ def __init__(
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
only_cross_attention=mid_block_only_cross_attention,
cross_attention_norm=cross_attention_norm,
)
elif mid_block_type is None:
self.mid_block = None
@@ -498,6 +501,7 @@ def __init__(
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -1490,6 +1494,7 @@ def __init__(
cross_attention_dim=1280,
skip_time_act=False,
only_cross_attention=False,
cross_attention_norm=None,
):
super().__init__()

@@ -1530,6 +1535,7 @@ def __init__(
bias=True,
upcast_softmax=True,
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
processor=AttnAddedKVProcessor(),
)
)