@@ -56,7 +56,8 @@ def __init__(
         bias=False,
         upcast_attention: bool = False,
         upcast_softmax: bool = False,
-        cross_attention_norm: bool = False,
+        cross_attention_norm: Optional[str] = None,
+        cross_attention_norm_num_groups: int = 32,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
         out_bias: bool = True,
@@ -69,7 +70,6 @@ def __init__(
         cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
-        self.cross_attention_norm = cross_attention_norm

         self.scale = dim_head**-0.5 if scale_qk else 1.0

@@ -92,8 +92,28 @@ def __init__(
         else:
             self.group_norm = None

-        if cross_attention_norm:
+        if cross_attention_norm is None:
+            self.norm_cross = None
+        elif cross_attention_norm == "layer_norm":
             self.norm_cross = nn.LayerNorm(cross_attention_dim)
+        elif cross_attention_norm == "group_norm":
+            if self.added_kv_proj_dim is not None:
+                # The given `encoder_hidden_states` are initially of shape
+                # (batch_size, seq_len, added_kv_proj_dim) before being projected
+                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+                # before the projection, so we need to use `added_kv_proj_dim` as
+                # the number of channels for the group norm.
+                norm_cross_num_channels = added_kv_proj_dim
+            else:
+                norm_cross_num_channels = cross_attention_dim
+
+            self.norm_cross = nn.GroupNorm(
+                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+            )
+        else:
+            raise ValueError(
+                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+            )

         self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)

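To illustrate the selection logic added above, here is a minimal standalone sketch in plain PyTorch (not the library's API); the helper name `build_norm_cross` and the example dimensions are made up for this illustration.

from typing import Optional

import torch
from torch import nn


def build_norm_cross(
    cross_attention_norm: Optional[str],
    cross_attention_dim: int,
    added_kv_proj_dim: Optional[int] = None,
    cross_attention_norm_num_groups: int = 32,
) -> Optional[nn.Module]:
    # Mirrors the branching in the diff: None -> no norm, "layer_norm" -> LayerNorm over
    # the cross-attention dim, "group_norm" -> GroupNorm over the channel count that the
    # encoder hidden states actually have *before* any added-KV projection.
    if cross_attention_norm is None:
        return None
    if cross_attention_norm == "layer_norm":
        return nn.LayerNorm(cross_attention_dim)
    if cross_attention_norm == "group_norm":
        num_channels = added_kv_proj_dim if added_kv_proj_dim is not None else cross_attention_dim
        return nn.GroupNorm(
            num_channels=num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
        )
    raise ValueError(f"unknown cross_attention_norm: {cross_attention_norm}")


# Example dimensions are arbitrary: 1280-channel encoder states fed through added-KV projections.
norm = build_norm_cross("group_norm", cross_attention_dim=768, added_kv_proj_dim=1280)
print(norm)  # GroupNorm(32, 1280, eps=1e-05, affine=True)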
@@ -304,6 +324,25 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
             attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
         return attention_mask

+    def norm_encoder_hidden_states(self, encoder_hidden_states):
+        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+
+        if isinstance(self.norm_cross, nn.LayerNorm):
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+        elif isinstance(self.norm_cross, nn.GroupNorm):
+            # Group norm norms along the channels dimension and expects
+            # input to be in the shape of (N, C, *). In this case, we want
+            # to norm along the hidden dimension, so we need to move
+            # (batch_size, sequence_length, hidden_size) ->
+            # (batch_size, hidden_size, sequence_length)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+        else:
+            assert False
+
+        return encoder_hidden_states
+

 class AttnProcessor:
     def __call__(
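The group-norm branch of the new `norm_encoder_hidden_states` method normalizes over the hidden dimension by temporarily moving it into the channel position. A small self-contained sketch of the same transpose pattern, with dummy shapes chosen only for illustration:

import torch
from torch import nn

batch_size, seq_len, hidden_size = 2, 77, 1280  # arbitrary example shapes
encoder_hidden_states = torch.randn(batch_size, seq_len, hidden_size)

group_norm = nn.GroupNorm(num_channels=hidden_size, num_groups=32, eps=1e-5, affine=True)

# GroupNorm expects (N, C, *), so move the hidden dimension into the channel slot,
# normalize, then move it back to (batch, seq, hidden).
normed = group_norm(encoder_hidden_states.transpose(1, 2)).transpose(1, 2)
assert normed.shape == encoder_hidden_states.shape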
@@ -321,8 +360,8 @@ def __call__(

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -388,7 +427,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query)

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -416,6 +458,11 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

         query = attn.to_q(hidden_states)
@@ -467,8 +514,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -511,8 +558,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -561,7 +608,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query).contiguous()

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -598,8 +648,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -647,6 +697,11 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

         query = attn.to_q(hidden_states)
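Taken together, the processors now default `encoder_hidden_states` to `hidden_states` for self-attention and route cross-attention states through `attn.norm_encoder_hidden_states` whenever `norm_cross` is set. A hedged usage sketch of the new constructor arguments follows; the import path, the `heads`/`dim_head` keyword names, and all dimensions are assumptions for illustration and are not part of this diff.

import torch
from diffusers.models.attention_processor import Attention  # assumed import path

# Cross-attention block whose encoder hidden states are group-normalized
# before the key/value projections; dimensions are illustrative only.
attn = Attention(
    query_dim=320,
    cross_attention_dim=1024,
    heads=8,
    dim_head=40,
    cross_attention_norm="group_norm",   # new: None | "layer_norm" | "group_norm"
    cross_attention_norm_num_groups=32,  # new: group count for the GroupNorm
)

hidden_states = torch.randn(1, 64, 320)
encoder_hidden_states = torch.randn(1, 77, 1024)
out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # expected: torch.Size([1, 64, 320])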