@@ -56,7 +56,8 @@ def __init__(
         bias=False,
         upcast_attention: bool = False,
         upcast_softmax: bool = False,
-        cross_attention_norm: bool = False,
+        cross_attention_norm: Optional[str] = None,
+        cross_attention_norm_num_groups: int = 32,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
         out_bias: bool = True,
@@ -68,7 +69,6 @@ def __init__(
         cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
-        self.cross_attention_norm = cross_attention_norm

         self.scale = dim_head**-0.5 if scale_qk else 1.0

@@ -85,8 +85,28 @@ def __init__(
         else:
             self.group_norm = None

-        if cross_attention_norm:
+        if cross_attention_norm is None:
+            self.norm_cross = None
+        elif cross_attention_norm == "layer_norm":
             self.norm_cross = nn.LayerNorm(cross_attention_dim)
+        elif cross_attention_norm == "group_norm":
+            if self.added_kv_proj_dim is not None:
+                # The given `encoder_hidden_states` are initially of shape
+                # (batch_size, seq_len, added_kv_proj_dim) before being projected
+                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+                # before the projection, so we need to use `added_kv_proj_dim` as
+                # the number of channels for the group norm.
+                norm_cross_num_channels = added_kv_proj_dim
+            else:
+                norm_cross_num_channels = cross_attention_dim
+
+            self.norm_cross = nn.GroupNorm(
+                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+            )
+        else:
+            raise ValueError(
+                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+            )

         self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
         self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
@@ -291,6 +311,29 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
             attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
         return attention_mask

+    def prepare_encoder_hidden_states(self, hidden_states, encoder_hidden_states=None):
+        if encoder_hidden_states is None:
+            return hidden_states
+
+        if self.norm_cross is None:
+            return encoder_hidden_states
+
+        if isinstance(self.norm_cross, nn.LayerNorm):
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+        elif isinstance(self.norm_cross, nn.GroupNorm):
+            # Group norm norms along the channels dimension and expects
+            # input to be in the shape of (N, C, *). In this case, we want
+            # to norm along the hidden dimension, so we need to move
+            # (batch_size, sequence_length, hidden_size) ->
+            # (batch_size, hidden_size, sequence_length)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+        else:
+            assert False
+
+        return encoder_hidden_states
+

 class AttnProcessor:
     def __call__(
@@ -306,10 +349,7 @@ def __call__(
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -375,7 +415,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query)

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -402,6 +442,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         batch_size, sequence_length, _ = hidden_states.shape

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

@@ -449,10 +490,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -493,10 +531,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -545,7 +580,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query).contiguous()

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -580,10 +615,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         dim = query.shape[-1]
         query = attn.head_to_batch_dim(query)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -630,6 +662,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
         batch_size, sequence_length, _ = hidden_states.shape

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

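Not part of the patch: a minimal standalone sketch of the GroupNorm transpose trick that the new `prepare_encoder_hidden_states` helper applies when `cross_attention_norm="group_norm"`. The tensor sizes below are assumed for illustration only; only the transpose/normalize/transpose pattern and the `num_groups` default of 32 come from the diff.

import torch
from torch import nn

# Hypothetical shapes, not taken from the diff.
batch_size, seq_len, hidden_size = 2, 77, 768
num_groups = 32  # mirrors the new `cross_attention_norm_num_groups` default

norm = nn.GroupNorm(num_channels=hidden_size, num_groups=num_groups, eps=1e-5, affine=True)

encoder_hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# GroupNorm expects input shaped (N, C, *), so move the hidden dimension into
# the channel slot, normalize, and move it back -- the same pattern used in
# `prepare_encoder_hidden_states` above.
out = norm(encoder_hidden_states.transpose(1, 2)).transpose(1, 2)
assert out.shape == (batch_size, seq_len, hidden_size)

With this layout, GroupNorm computes statistics per sample over each group of hidden channels together with all sequence positions, which is what the transposed call in the helper produces.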