Commit 96e8671

add group norm type to attention processor cross attention norm

This lets the cross attention norm use either a group norm block or a layer norm block. The group norm operates along the channels dimension and requires input of shape (batch size, channels, *), whereas a layer norm with a single `normalized_shape` dimension only operates over the last dimension, i.e. (*, channels). The channels we want to normalize are the hidden dimension of the encoder hidden states. By convention, the encoder hidden states are always passed as (batch size, sequence length, hidden dim). This means the layer norm can operate on the tensor without modification, but the group norm requires flipping the last two dimensions to operate on (batch size, hidden dim, sequence length). All existing attention processors share this logic, so it is consolidated in a helper function `prepare_encoder_hidden_states`.
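For intuition, a minimal standalone PyTorch sketch (not part of the diff; shapes and group count are arbitrary) of why the group norm path needs the transpose while the layer norm path does not:

import torch
import torch.nn as nn

batch_size, seq_len, hidden_dim = 2, 77, 768
encoder_hidden_states = torch.randn(batch_size, seq_len, hidden_dim)

# LayerNorm normalizes over the trailing dimension, so (batch, seq, hidden) works as-is.
layer_norm = nn.LayerNorm(hidden_dim)
out_ln = layer_norm(encoder_hidden_states)

# GroupNorm normalizes over the channels dimension and expects (batch, channels, *),
# so the last two dimensions are swapped before the norm and swapped back afterwards.
group_norm = nn.GroupNorm(num_groups=32, num_channels=hidden_dim, eps=1e-5, affine=True)
out_gn = group_norm(encoder_hidden_states.transpose(1, 2)).transpose(1, 2)

assert out_ln.shape == out_gn.shape == (batch_size, seq_len, hidden_dim)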
1 parent 67c3518 commit 96e8671

6 files changed, +80 -33 lines changed


src/diffusers/models/attention_processor.py

Lines changed: 54 additions & 21 deletions

@@ -56,7 +56,8 @@ def __init__(
         bias=False,
         upcast_attention: bool = False,
         upcast_softmax: bool = False,
-        cross_attention_norm: bool = False,
+        cross_attention_norm: Optional[str] = None,
+        cross_attention_norm_num_groups: int = 32,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
         out_bias: bool = True,
@@ -68,7 +69,6 @@ def __init__(
         cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
-        self.cross_attention_norm = cross_attention_norm

         self.scale = dim_head**-0.5 if scale_qk else 1.0

@@ -85,8 +85,28 @@ def __init__(
         else:
             self.group_norm = None

-        if cross_attention_norm:
+        if cross_attention_norm is None:
+            self.norm_cross = None
+        elif cross_attention_norm == "layer_norm":
             self.norm_cross = nn.LayerNorm(cross_attention_dim)
+        elif cross_attention_norm == "group_norm":
+            if self.added_kv_proj_dim is not None:
+                # The given `encoder_hidden_states` are initially of shape
+                # (batch_size, seq_len, added_kv_proj_dim) before being projected
+                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+                # before the projection, so we need to use `added_kv_proj_dim` as
+                # the number of channels for the group norm.
+                norm_cross_num_channels = added_kv_proj_dim
+            else:
+                norm_cross_num_channels = cross_attention_dim
+
+            self.norm_cross = nn.GroupNorm(
+                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+            )
+        else:
+            raise ValueError(
+                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+            )

         self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
         self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
@@ -291,6 +311,29 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
             attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
         return attention_mask

+    def prepare_encoder_hidden_states(self, hidden_states, encoder_hidden_states=None):
+        if encoder_hidden_states is None:
+            return hidden_states
+
+        if self.norm_cross is None:
+            return encoder_hidden_states
+
+        if isinstance(self.norm_cross, nn.LayerNorm):
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+        elif isinstance(self.norm_cross, nn.GroupNorm):
+            # Group norm norms along the channels dimension and expects
+            # input to be in the shape of (N, C, *). In this case, we want
+            # to norm along the hidden dimension, so we need to move
+            # (batch_size, sequence_length, hidden_size) ->
+            # (batch_size, hidden_size, sequence_length)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+        else:
+            assert False
+
+        return encoder_hidden_states
+

 class AttnProcessor:
     def __call__(
@@ -306,10 +349,7 @@ def __call__(
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -375,7 +415,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query)

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -402,6 +442,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         batch_size, sequence_length, _ = hidden_states.shape

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

@@ -449,10 +490,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -493,10 +531,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -545,7 +580,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query).contiguous()

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -580,10 +615,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         dim = query.shape[-1]
         query = attn.head_to_batch_dim(query)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -630,6 +662,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
         batch_size, sequence_length, _ = hidden_states.shape

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

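As a quick sanity check of the new option, a hedged usage sketch — it assumes the class defined in attention_processor.py is importable as `Attention` from `diffusers.models.attention_processor` at this commit, and that the constructor arguments not shown in the diff keep their defaults; the dimensions below are arbitrary:

import torch
from diffusers.models.attention_processor import Attention  # assumed import path

# Construct an attention block whose cross attention norm is a group norm
# (arbitrary example dimensions; 768 is divisible by the 32 groups).
attn = Attention(
    query_dim=320,
    cross_attention_dim=768,
    cross_attention_norm="group_norm",
    cross_attention_norm_num_groups=32,
)

hidden_states = torch.randn(2, 64, 320)
encoder_hidden_states = torch.randn(2, 77, 768)

# The new helper applies the norm over the hidden dimension and returns the
# tensor in the original (batch, seq_len, hidden) layout.
out = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
print(out.shape)  # torch.Size([2, 77, 768])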
src/diffusers/models/unet_2d_blocks.py

Lines changed: 14 additions & 4 deletions

@@ -44,6 +44,7 @@ def get_down_block(
     resnet_time_scale_shift="default",
     resnet_skip_time_act=False,
     resnet_out_scale_factor=1.0,
+    cross_attention_norm=None,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -125,6 +126,7 @@ def get_down_block(
             resnet_time_scale_shift=resnet_time_scale_shift,
             skip_time_act=resnet_skip_time_act,
             output_scale_factor=resnet_out_scale_factor,
+            cross_attention_norm=cross_attention_norm,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -222,6 +224,7 @@ def get_up_block(
     resnet_time_scale_shift="default",
     resnet_skip_time_act=False,
     resnet_out_scale_factor=1.0,
+    cross_attention_norm=None,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -291,6 +294,7 @@ def get_up_block(
             resnet_time_scale_shift=resnet_time_scale_shift,
             skip_time_act=resnet_skip_time_act,
             output_scale_factor=resnet_out_scale_factor,
+            cross_attention_norm=cross_attention_norm,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -575,6 +579,7 @@ def __init__(
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         skip_time_act=False,
+        cross_attention_norm=None,
     ):
         super().__init__()

@@ -614,6 +619,7 @@ def __init__(
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
+                    cross_attention_norm=cross_attention_norm,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -1356,6 +1362,7 @@ def __init__(
         output_scale_factor=1.0,
         add_downsample=True,
         skip_time_act=False,
+        cross_attention_norm=None,
     ):
         super().__init__()

@@ -1394,6 +1401,7 @@ def __init__(
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
+                    cross_attention_norm=cross_attention_norm,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -1574,7 +1582,7 @@ def __init__(
                     temb_channels=temb_channels,
                     attention_bias=True,
                     add_self_attention=add_self_attention,
-                    cross_attention_norm=True,
+                    cross_attention_norm="layer_norm",
                     group_size=resnet_group_size,
                 )
             )
@@ -2354,6 +2362,7 @@ def __init__(
         output_scale_factor=1.0,
         add_upsample=True,
         skip_time_act=False,
+        cross_attention_norm=None,
     ):
         super().__init__()
         resnets = []
@@ -2393,6 +2402,7 @@ def __init__(
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
+                    cross_attention_norm=cross_attention_norm,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -2600,7 +2610,7 @@ def __init__(
                     temb_channels=temb_channels,
                     attention_bias=True,
                     add_self_attention=add_self_attention,
-                    cross_attention_norm=True,
+                    cross_attention_norm="layer_norm",
                     upcast_attention=upcast_attention,
                 )
             )
@@ -2695,7 +2705,7 @@ def __init__(
         upcast_attention: bool = False,
         temb_channels: int = 768,  # for ada_group_norm
         add_self_attention: bool = False,
-        cross_attention_norm: bool = False,
+        cross_attention_norm: Optional[str] = None,
         group_size: int = 32,
     ):
         super().__init__()
@@ -2711,7 +2721,7 @@ def __init__(
             dropout=dropout,
             bias=attention_bias,
             cross_attention_dim=None,
-            cross_attention_norm=False,
+            cross_attention_norm=None,
         )

         # 2. Cross-Attn

src/diffusers/models/unet_2d_condition.py

Lines changed: 4 additions & 0 deletions

@@ -158,6 +158,7 @@ def __init__(
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
         class_embeddings_concat: bool = False,
+        cross_attention_norm: Optional[str] = None,
     ):
         super().__init__()

@@ -311,6 +312,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
             )
             self.down_blocks.append(down_block)

@@ -342,6 +344,7 @@ def __init__(
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 skip_time_act=resnet_skip_time_act,
+                cross_attention_norm=cross_attention_norm,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -393,6 +396,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py

Lines changed: 1 addition & 4 deletions

@@ -241,10 +241,7 @@ def __call__(
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

Lines changed: 1 addition & 4 deletions

@@ -63,10 +63,7 @@ def __call__(
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 6 additions & 0 deletions

@@ -244,6 +244,7 @@ def __init__(
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
         class_embeddings_concat: bool = False,
+        cross_attention_norm: Optional[str] = None,
     ):
         super().__init__()

@@ -403,6 +404,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
             )
             self.down_blocks.append(down_block)

@@ -434,6 +436,7 @@ def __init__(
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 skip_time_act=resnet_skip_time_act,
+                cross_attention_norm=cross_attention_norm,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -485,6 +488,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -1463,6 +1467,7 @@ def __init__(
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         skip_time_act=False,
+        cross_attention_norm=None,
     ):
         super().__init__()

@@ -1502,6 +1507,7 @@ def __init__(
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
+                    cross_attention_norm=cross_attention_norm,
                     processor=AttnAddedKVProcessor(),
                 )
             )
