Commit 1dd2774
add group norm type to attention processor cross attention norm
This lets the cross attention norm use either a group norm block or a layer norm block. The group norm operates along the channels dimension and requires input of shape (batch_size, channels, *), whereas a layer norm with a single `normalized_shape` dimension only operates over the last dimension, i.e. (*, channels). The channels we want to normalize are the hidden dimension of the encoder hidden states. By convention, the encoder hidden states are always passed as (batch_size, sequence_length, hidden_size), so the layer norm can operate on the tensor without modification, but the group norm requires flipping the last two dimensions to get (batch_size, hidden_size, sequence_length). All existing attention processors need the same logic, so we consolidate it in a helper function, `prepare_encoder_hidden_states`.
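
A minimal standalone PyTorch sketch of the shape handling described above (not part of this diff; the sizes are arbitrary):

import torch
from torch import nn

batch_size, sequence_length, hidden_size = 2, 77, 320

# Encoder hidden states are passed as (batch_size, sequence_length, hidden_size).
encoder_hidden_states = torch.randn(batch_size, sequence_length, hidden_size)

# LayerNorm with a single normalized_shape dimension operates over the last
# dimension, so it can be applied to the tensor as-is.
layer_norm = nn.LayerNorm(hidden_size)
out_ln = layer_norm(encoder_hidden_states)

# GroupNorm normalizes over the channels dimension and expects (N, C, *),
# so the last two dimensions are flipped before and after the norm.
group_norm = nn.GroupNorm(num_groups=32, num_channels=hidden_size)
out_gn = group_norm(encoder_hidden_states.transpose(1, 2)).transpose(1, 2)

assert out_ln.shape == out_gn.shape == encoder_hidden_states.shape
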
1 parent b8d88b8 · commit 1dd2774

6 files changed: +80 -33 lines

src/diffusers/models/attention_processor.py

Lines changed: 54 additions & 21 deletions

@@ -56,7 +56,8 @@ def __init__(
         bias=False,
         upcast_attention: bool = False,
         upcast_softmax: bool = False,
-        cross_attention_norm: bool = False,
+        cross_attention_norm: Optional[str] = None,
+        cross_attention_norm_num_groups: int = 32,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
         out_bias: bool = True,
@@ -68,7 +69,6 @@ def __init__(
         cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
-        self.cross_attention_norm = cross_attention_norm

         self.scale = dim_head**-0.5 if scale_qk else 1.0

@@ -85,8 +85,28 @@ def __init__(
         else:
             self.group_norm = None

-        if cross_attention_norm:
+        if cross_attention_norm is None:
+            self.norm_cross = None
+        elif cross_attention_norm == "layer_norm":
             self.norm_cross = nn.LayerNorm(cross_attention_dim)
+        elif cross_attention_norm == "group_norm":
+            if self.added_kv_proj_dim is not None:
+                # The given `encoder_hidden_states` are initially of shape
+                # (batch_size, seq_len, added_kv_proj_dim) before being projected
+                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+                # before the projection, so we need to use `added_kv_proj_dim` as
+                # the number of channels for the group norm.
+                norm_cross_num_channels = added_kv_proj_dim
+            else:
+                norm_cross_num_channels = cross_attention_dim
+
+            self.norm_cross = nn.GroupNorm(
+                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+            )
+        else:
+            raise ValueError(
+                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+            )

         self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
         self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
@@ -291,6 +311,29 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
         attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
         return attention_mask

+    def prepare_encoder_hidden_states(self, hidden_states, encoder_hidden_states=None):
+        if encoder_hidden_states is None:
+            return hidden_states
+
+        if self.norm_cross is None:
+            return encoder_hidden_states
+
+        if isinstance(self.norm_cross, nn.LayerNorm):
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+        elif isinstance(self.norm_cross, nn.GroupNorm):
+            # Group norm norms along the channels dimension and expects
+            # input to be in the shape of (N, C, *). In this case, we want
+            # to norm along the hidden dimension, so we need to move
+            # (batch_size, sequence_length, hidden_size) ->
+            # (batch_size, hidden_size, sequence_length)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+        else:
+            assert False
+
+        return encoder_hidden_states
+

 class AttnProcessor:
     def __call__(
@@ -306,10 +349,7 @@ def __call__(
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -375,7 +415,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query)

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -402,6 +442,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         batch_size, sequence_length, _ = hidden_states.shape

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

@@ -449,10 +490,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -493,10 +531,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -545,7 +580,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query).contiguous()

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -580,10 +615,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         dim = query.shape[-1]
         query = attn.head_to_batch_dim(query)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -630,6 +662,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
         batch_size, sequence_length, _ = hidden_states.shape

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
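
A rough usage sketch against the `Attention` class in this file; the dimensions below are made up for illustration, and only `cross_attention_norm` / `cross_attention_norm_num_groups` come from this commit:

import torch
from diffusers.models.attention_processor import Attention

attn = Attention(
    query_dim=320,
    cross_attention_dim=768,
    heads=8,
    dim_head=40,
    cross_attention_norm="group_norm",
    cross_attention_norm_num_groups=32,
)

hidden_states = torch.randn(2, 4096, 320)
encoder_hidden_states = torch.randn(2, 77, 768)

# Returns `hidden_states` unchanged for self-attention; otherwise applies the
# configured norm, transposing around the group norm as described above.
normed = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
print(type(attn.norm_cross).__name__, tuple(normed.shape))  # GroupNorm (2, 77, 768)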

src/diffusers/models/unet_2d_blocks.py

Lines changed: 14 additions & 4 deletions

@@ -42,6 +42,7 @@ def get_down_block(
     only_cross_attention=False,
     upcast_attention=False,
     resnet_time_scale_shift="default",
+    cross_attention_norm=None,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -119,6 +120,7 @@ def get_down_block(
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            cross_attention_norm=cross_attention_norm,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -214,6 +216,7 @@ def get_up_block(
     only_cross_attention=False,
     upcast_attention=False,
     resnet_time_scale_shift="default",
+    cross_attention_norm=None,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -279,6 +282,7 @@ def get_up_block(
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            cross_attention_norm=cross_attention_norm,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -562,6 +566,7 @@ def __init__(
         attn_num_head_channels=1,
         output_scale_factor=1.0,
         cross_attention_dim=1280,
+        cross_attention_norm=None,
     ):
         super().__init__()

@@ -600,6 +605,7 @@ def __init__(
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
+                    cross_attention_norm=cross_attention_norm,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -1337,6 +1343,7 @@ def __init__(
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         add_downsample=True,
+        cross_attention_norm=None,
     ):
         super().__init__()

@@ -1374,6 +1381,7 @@ def __init__(
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
+                    cross_attention_norm=cross_attention_norm,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -1553,7 +1561,7 @@ def __init__(
                     temb_channels=temb_channels,
                     attention_bias=True,
                     add_self_attention=add_self_attention,
-                    cross_attention_norm=True,
+                    cross_attention_norm="layer_norm",
                     group_size=resnet_group_size,
                 )
             )
@@ -2329,6 +2337,7 @@ def __init__(
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         add_upsample=True,
+        cross_attention_norm=None,
     ):
         super().__init__()
         resnets = []
@@ -2367,6 +2376,7 @@ def __init__(
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
+                    cross_attention_norm=cross_attention_norm,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -2573,7 +2583,7 @@ def __init__(
                     temb_channels=temb_channels,
                     attention_bias=True,
                     add_self_attention=add_self_attention,
-                    cross_attention_norm=True,
+                    cross_attention_norm="layer_norm",
                     upcast_attention=upcast_attention,
                 )
             )
@@ -2668,7 +2678,7 @@ def __init__(
         upcast_attention: bool = False,
         temb_channels: int = 768,  # for ada_group_norm
        add_self_attention: bool = False,
-        cross_attention_norm: bool = False,
+        cross_attention_norm: Optional[str] = None,
         group_size: int = 32,
     ):
         super().__init__()
@@ -2684,7 +2694,7 @@ def __init__(
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=None,
-           cross_attention_norm=False,
+           cross_attention_norm=None,
        )

        # 2. Cross-Attn

src/diffusers/models/unet_2d_condition.py

Lines changed: 4 additions & 0 deletions

@@ -153,6 +153,7 @@ def __init__(
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
         class_embeddings_concat: bool = False,
+        cross_attention_norm: Optional[str] = None,
     ):
         super().__init__()

@@ -291,6 +292,7 @@ def __init__(
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                cross_attention_norm=cross_attention_norm,
             )
             self.down_blocks.append(down_block)

@@ -321,6 +323,7 @@ def __init__(
                 attn_num_head_channels=attention_head_dim[-1],
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                cross_attention_norm=cross_attention_norm,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -369,6 +372,7 @@ def __init__(
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                cross_attention_norm=cross_attention_norm,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
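
From the wiring above, the new config value is forwarded only to the blocks built around AttnAddedKVProcessor (the SimpleCrossAttn* blocks and UNetMidBlock2DSimpleCrossAttn). A hypothetical small configuration that exercises this path (all sizes made up) might look like:

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("SimpleCrossAttnDownBlock2D", "ResnetDownsampleBlock2D"),
    up_block_types=("ResnetUpsampleBlock2D", "SimpleCrossAttnUpBlock2D"),
    mid_block_type="UNetMidBlock2DSimpleCrossAttn",
    cross_attention_dim=32,
    cross_attention_norm="group_norm",
)

# The cross attention norm ends up as a GroupNorm over `added_kv_proj_dim` channels.
print(unet.mid_block.attentions[0].norm_cross)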

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py

Lines changed: 1 addition & 4 deletions

@@ -241,10 +241,7 @@ def __call__(
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

Lines changed: 1 addition & 4 deletions

@@ -63,10 +63,7 @@ def __call__(
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 6 additions & 0 deletions

@@ -239,6 +239,7 @@ def __init__(
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
         class_embeddings_concat: bool = False,
+        cross_attention_norm: Optional[str] = None,
     ):
         super().__init__()

@@ -382,6 +383,7 @@ def __init__(
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                cross_attention_norm=cross_attention_norm,
             )
             self.down_blocks.append(down_block)

@@ -412,6 +414,7 @@ def __init__(
                 attn_num_head_channels=attention_head_dim[-1],
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                cross_attention_norm=cross_attention_norm,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -460,6 +463,7 @@ def __init__(
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                cross_attention_norm=cross_attention_norm,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -1434,6 +1438,7 @@ def __init__(
         attn_num_head_channels=1,
         output_scale_factor=1.0,
         cross_attention_dim=1280,
+        cross_attention_norm=None,
     ):
         super().__init__()

@@ -1472,6 +1477,7 @@ def __init__(
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
+                    cross_attention_norm=cross_attention_norm,
                     processor=AttnAddedKVProcessor(),
                 )
             )
