Commit 07b9368

williamberman authored and dg845 committed
Attention processor cross attention norm group norm (huggingface#3021)
Add a group norm type to the attention processor's cross attention norm. This lets the cross attention norm use either a group norm block or a layer norm block. The group norm operates along the channels dimension and requires input of shape (batch size, channels, *), whereas a layer norm with a single `normalized_shape` dimension only operates over the least significant dimension, i.e. (*, channels). The channels we want to normalize are the hidden dimension of the encoder hidden states. By convention, the encoder hidden states are always passed as (batch size, sequence length, hidden size). This means the layer norm can operate on the tensor without modification, but the group norm requires flipping the last two dimensions to operate on (batch size, hidden size, sequence length). All existing attention processors share this logic, so it is consolidated in a helper function `prepare_encoder_hidden_states`.

* prepare_encoder_hidden_states -> norm_encoder_hidden_states re: @patrickvonplaten
* move the norm_cross defined check to outside norm_encoder_hidden_states
* add missing attn.norm_cross check
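To make the shape handling concrete, here is a minimal standalone sketch in plain PyTorch (not part of this commit; the dimension sizes are arbitrary) showing that a layer norm consumes (batch size, sequence length, hidden size) directly, while a group norm needs the transpose round-trip described above:

import torch
import torch.nn as nn

batch_size, seq_len, hidden_size = 2, 77, 1024
encoder_hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# LayerNorm with a single normalized_shape normalizes the last dimension,
# so (batch_size, seq_len, hidden_size) can be passed in unchanged.
layer_norm = nn.LayerNorm(hidden_size)
out_ln = layer_norm(encoder_hidden_states)

# GroupNorm normalizes the channels dimension of (N, C, *), so the hidden
# dimension has to be moved into the channel position and then moved back.
group_norm = nn.GroupNorm(num_groups=32, num_channels=hidden_size)
out_gn = group_norm(encoder_hidden_states.transpose(1, 2)).transpose(1, 2)

assert out_ln.shape == out_gn.shape == (batch_size, seq_len, hidden_size)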
1 parent e38bc51 commit 07b9368

6 files changed: +96, -21 lines

src/diffusers/models/attention_processor.py

Lines changed: 68 additions & 13 deletions
@@ -56,7 +56,8 @@ def __init__(
         bias=False,
         upcast_attention: bool = False,
         upcast_softmax: bool = False,
-        cross_attention_norm: bool = False,
+        cross_attention_norm: Optional[str] = None,
+        cross_attention_norm_num_groups: int = 32,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
         out_bias: bool = True,
@@ -69,7 +70,6 @@ def __init__(
         cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
-        self.cross_attention_norm = cross_attention_norm

         self.scale = dim_head**-0.5 if scale_qk else 1.0

@@ -92,8 +92,28 @@ def __init__(
         else:
             self.group_norm = None

-        if cross_attention_norm:
+        if cross_attention_norm is None:
+            self.norm_cross = None
+        elif cross_attention_norm == "layer_norm":
             self.norm_cross = nn.LayerNorm(cross_attention_dim)
+        elif cross_attention_norm == "group_norm":
+            if self.added_kv_proj_dim is not None:
+                # The given `encoder_hidden_states` are initially of shape
+                # (batch_size, seq_len, added_kv_proj_dim) before being projected
+                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+                # before the projection, so we need to use `added_kv_proj_dim` as
+                # the number of channels for the group norm.
+                norm_cross_num_channels = added_kv_proj_dim
+            else:
+                norm_cross_num_channels = cross_attention_dim
+
+            self.norm_cross = nn.GroupNorm(
+                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+            )
+        else:
+            raise ValueError(
+                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+            )

         self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)

@@ -304,6 +324,25 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
             attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
         return attention_mask

+    def norm_encoder_hidden_states(self, encoder_hidden_states):
+        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+
+        if isinstance(self.norm_cross, nn.LayerNorm):
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+        elif isinstance(self.norm_cross, nn.GroupNorm):
+            # Group norm norms along the channels dimension and expects
+            # input to be in the shape of (N, C, *). In this case, we want
+            # to norm along the hidden dimension, so we need to move
+            # (batch_size, sequence_length, hidden_size) ->
+            # (batch_size, hidden_size, sequence_length)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+        else:
+            assert False
+
+        return encoder_hidden_states
+

 class AttnProcessor:
     def __call__(
@@ -321,8 +360,8 @@ def __call__(

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -388,7 +427,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query)

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -416,6 +458,11 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

         query = attn.to_q(hidden_states)
@@ -467,8 +514,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -511,8 +558,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -561,7 +608,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query).contiguous()

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -598,8 +648,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -647,6 +697,11 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

         query = attn.to_q(hidden_states)
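A hedged usage sketch (not part of the diff): constructing an Attention block with the new option and applying the new helper. The keyword names come from the changes above, the import path follows this file's path, and the concrete sizes are illustrative assumptions.

import torch
from diffusers.models.attention_processor import Attention

# cross_attention_norm used to be a bool; it now accepts None, "layer_norm", or "group_norm".
attn = Attention(
    query_dim=320,
    cross_attention_dim=1024,  # hidden size of the encoder hidden states
    cross_attention_norm="group_norm",
    cross_attention_norm_num_groups=32,
)

encoder_hidden_states = torch.randn(2, 77, 1024)  # (batch_size, seq_len, hidden_size)
normed = attn.norm_encoder_hidden_states(encoder_hidden_states)
assert normed.shape == encoder_hidden_states.shape  # shape is preserved; only the values are normalized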

src/diffusers/models/unet_2d_blocks.py

Lines changed: 14 additions & 4 deletions
@@ -44,6 +44,7 @@ def get_down_block(
     resnet_time_scale_shift="default",
     resnet_skip_time_act=False,
     resnet_out_scale_factor=1.0,
+    cross_attention_norm=None,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -126,6 +127,7 @@ def get_down_block(
             skip_time_act=resnet_skip_time_act,
             output_scale_factor=resnet_out_scale_factor,
             only_cross_attention=only_cross_attention,
+            cross_attention_norm=cross_attention_norm,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -223,6 +225,7 @@ def get_up_block(
     resnet_time_scale_shift="default",
     resnet_skip_time_act=False,
     resnet_out_scale_factor=1.0,
+    cross_attention_norm=None,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -293,6 +296,7 @@ def get_up_block(
             skip_time_act=resnet_skip_time_act,
             output_scale_factor=resnet_out_scale_factor,
             only_cross_attention=only_cross_attention,
+            cross_attention_norm=cross_attention_norm,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -578,6 +582,7 @@ def __init__(
         cross_attention_dim=1280,
         skip_time_act=False,
         only_cross_attention=False,
+        cross_attention_norm=None,
     ):
         super().__init__()

@@ -618,6 +623,7 @@ def __init__(
                     bias=True,
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
+                    cross_attention_norm=cross_attention_norm,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -1361,6 +1367,7 @@ def __init__(
         add_downsample=True,
         skip_time_act=False,
         only_cross_attention=False,
+        cross_attention_norm=None,
     ):
         super().__init__()

@@ -1400,6 +1407,7 @@ def __init__(
                     bias=True,
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
+                    cross_attention_norm=cross_attention_norm,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -1580,7 +1588,7 @@ def __init__(
                     temb_channels=temb_channels,
                     attention_bias=True,
                     add_self_attention=add_self_attention,
-                    cross_attention_norm=True,
+                    cross_attention_norm="layer_norm",
                     group_size=resnet_group_size,
                 )
             )
@@ -2361,6 +2369,7 @@ def __init__(
         add_upsample=True,
         skip_time_act=False,
         only_cross_attention=False,
+        cross_attention_norm=None,
     ):
         super().__init__()
         resnets = []
@@ -2401,6 +2410,7 @@ def __init__(
                     bias=True,
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
+                    cross_attention_norm=cross_attention_norm,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -2608,7 +2618,7 @@ def __init__(
                     temb_channels=temb_channels,
                     attention_bias=True,
                     add_self_attention=add_self_attention,
-                    cross_attention_norm=True,
+                    cross_attention_norm="layer_norm",
                     upcast_attention=upcast_attention,
                 )
             )
@@ -2703,7 +2713,7 @@ def __init__(
         upcast_attention: bool = False,
         temb_channels: int = 768,  # for ada_group_norm
         add_self_attention: bool = False,
-        cross_attention_norm: bool = False,
+        cross_attention_norm: Optional[str] = None,
         group_size: int = 32,
     ):
         super().__init__()
@@ -2719,7 +2729,7 @@ def __init__(
             dropout=dropout,
             bias=attention_bias,
             cross_attention_dim=None,
-            cross_attention_norm=False,
+            cross_attention_norm=None,
         )

         # 2. Cross-Attn

src/diffusers/models/unet_2d_condition.py

Lines changed: 4 additions & 0 deletions
@@ -169,6 +169,7 @@ def __init__(
         projection_class_embeddings_input_dim: Optional[int] = None,
         class_embeddings_concat: bool = False,
         mid_block_only_cross_attention: Optional[bool] = None,
+        cross_attention_norm: Optional[str] = None,
     ):
         super().__init__()

@@ -341,6 +342,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
             )
             self.down_blocks.append(down_block)

@@ -373,6 +375,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 skip_time_act=resnet_skip_time_act,
                 only_cross_attention=mid_block_only_cross_attention,
+                cross_attention_norm=cross_attention_norm,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -424,6 +427,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
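At the model level, the new option becomes a config argument. A hedged configuration sketch (not from the commit): per the get_down_block/get_up_block changes above, cross_attention_norm only reaches the blocks built around AttnAddedKVProcessor (SimpleCrossAttnDownBlock2D, UNetMidBlock2DSimpleCrossAttn, SimpleCrossAttnUpBlock2D); the block types and sizes below are illustrative assumptions, not a tested configuration.

from diffusers import UNet2DConditionModel

# Illustrative, untested tiny configuration; the cross_attention_norm kwarg is the point here.
unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64),
    down_block_types=("SimpleCrossAttnDownBlock2D", "SimpleCrossAttnDownBlock2D"),
    mid_block_type="UNetMidBlock2DSimpleCrossAttn",
    up_block_types=("SimpleCrossAttnUpBlock2D", "SimpleCrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
    cross_attention_norm="group_norm",  # new config option threaded through this file
)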

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py

Lines changed: 2 additions & 2 deletions
@@ -243,8 +243,8 @@ def __call__(

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

Lines changed: 2 additions & 2 deletions
@@ -65,8 +65,8 @@ def __call__(

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
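Both pipeline-local processors above follow the same pattern, which is also what any custom attention processor needs after this commit: check attn.norm_cross instead of the removed attn.cross_attention_norm flag and call the helper. A hedged sketch of such a processor, modeled on the default AttnProcessor; the attn helper methods used here come from attention_processor.py, but the class itself is hypothetical.

import torch

class MyAttnProcessor:
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            # replaces the old `attn.cross_attention_norm` / `attn.norm_cross(...)` pair
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        hidden_states = attn.to_out[0](hidden_states)  # linear projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout

        return hidden_states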

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 6 additions & 0 deletions
@@ -255,6 +255,7 @@ def __init__(
         projection_class_embeddings_input_dim: Optional[int] = None,
         class_embeddings_concat: bool = False,
         mid_block_only_cross_attention: Optional[bool] = None,
+        cross_attention_norm: Optional[str] = None,
     ):
         super().__init__()

@@ -433,6 +434,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
             )
             self.down_blocks.append(down_block)

@@ -465,6 +467,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 skip_time_act=resnet_skip_time_act,
                 only_cross_attention=mid_block_only_cross_attention,
+                cross_attention_norm=cross_attention_norm,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -516,6 +519,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -1511,6 +1515,7 @@ def __init__(
         cross_attention_dim=1280,
         skip_time_act=False,
         only_cross_attention=False,
+        cross_attention_norm=None,
     ):
         super().__init__()

@@ -1551,6 +1556,7 @@ def __init__(
                     bias=True,
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
+                    cross_attention_norm=cross_attention_norm,
                     processor=AttnAddedKVProcessor(),
                 )
             )
