@@ -22,7 +22,7 @@
 from ..image_processor import IPAdapterMaskProcessor
 from ..utils import deprecate, logging
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
-from ..utils.torch_utils import maybe_allow_in_graph
+from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -104,6 +104,7 @@ def __init__(
         cross_attention_norm_num_groups: int = 32,
         qk_norm: Optional[str] = None,
         added_kv_proj_dim: Optional[int] = None,
+        added_proj_bias: Optional[bool] = True,
         norm_num_groups: Optional[int] = None,
         spatial_norm_dim: Optional[int] = None,
         out_bias: bool = True,
@@ -118,6 +119,10 @@ def __init__(
         context_pre_only=None,
     ):
         super().__init__()
+
+        # To prevent circular import.
+        from .normalization import FP32LayerNorm
+
         self.inner_dim = out_dim if out_dim is not None else dim_head * heads
         self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
         self.query_dim = query_dim
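For context on the deferred import above: `FP32LayerNorm` keeps the Q/K normalization numerically stable under fp16/bf16 by running the norm in float32 and casting back to the input dtype. A minimal sketch of that pattern, assuming the usual `nn.LayerNorm` subclass approach rather than diffusers' exact implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FP32LayerNorm(nn.LayerNorm):
    # Illustrative sketch only: upcast to float32, normalize, cast back to the input dtype.
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        origin_dtype = inputs.dtype
        return F.layer_norm(
            inputs.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        ).to(origin_dtype)
```

With `elementwise_affine=False, bias=False`, as used throughout this diff, the norm carries no learnable parameters; only the float32 upcast matters.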
@@ -170,6 +175,9 @@ def __init__(
         elif qk_norm == "layer_norm":
             self.norm_q = nn.LayerNorm(dim_head, eps=eps)
             self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+        elif qk_norm == "fp32_layer_norm":
+            self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+            self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
         elif qk_norm == "layer_norm_across_heads":
             # Lumina applys qk norm across all heads
             self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
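The new `fp32_layer_norm` branch, like `layer_norm`, normalizes each head's `dim_head` channels (the processor added later in this diff applies it to tensors shaped `(batch, seq, heads, dim_head)`), whereas the Lumina-style `layer_norm_across_heads` builds one norm over `dim_head * heads`. A shape illustration with made-up sizes:

```python
import torch
import torch.nn as nn

batch, seq, heads, dim_head = 2, 16, 12, 256

# "layer_norm" / "fp32_layer_norm": one norm over dim_head, applied per head.
per_head_norm = nn.LayerNorm(dim_head)
q = torch.randn(batch, seq, heads, dim_head)
print(per_head_norm(q).shape)  # torch.Size([2, 16, 12, 256])

# "layer_norm_across_heads": one norm over all heads' channels, applied
# before the projection is reshaped into separate heads.
across_heads_norm = nn.LayerNorm(dim_head * heads)
q_flat = torch.randn(batch, seq, heads * dim_head)
print(across_heads_norm(q_flat).shape)  # torch.Size([2, 16, 3072])
```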
@@ -211,10 +219,10 @@ def __init__(
             self.to_v = None

         if self.added_kv_proj_dim is not None:
-            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim)
-            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim)
+            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
             if self.context_pre_only is not None:
-                self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+                self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)

         self.to_out = nn.ModuleList([])
         self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
@@ -223,6 +231,14 @@ def __init__(
         if self.context_pre_only is not None and not self.context_pre_only:
             self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)

+        if qk_norm is not None and added_kv_proj_dim is not None:
+            if qk_norm == "fp32_layer_norm":
+                self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+                self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+        else:
+            self.norm_added_q = None
+            self.norm_added_k = None
+
         # set attention processor
         # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
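Taken together, these `__init__` changes mean an `Attention` block configured with `qk_norm="fp32_layer_norm"`, `added_kv_proj_dim`, and `added_proj_bias=False` gets bias-free joint projections plus float32 norms on both the sample and the added (context) Q/K paths. A construction sketch with illustrative dimensions (not AuraFlow's actual config values), assuming this diff targets `diffusers.models.attention_processor`:

```python
from diffusers.models.attention_processor import Attention

attn = Attention(
    query_dim=3072,              # illustrative sizes only
    cross_attention_dim=None,
    heads=12,
    dim_head=256,
    qk_norm="fp32_layer_norm",   # new branch: per-head norms computed in float32
    added_kv_proj_dim=3072,      # enables add_q/k/v_proj and norm_added_q/k
    added_proj_bias=False,       # new flag: joint projections without bias terms
    context_pre_only=False,      # also builds add_q_proj and to_add_out
)

print(type(attn.norm_added_q).__name__)  # FP32LayerNorm
print(attn.add_k_proj.bias)              # None, because added_proj_bias=False
```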
@@ -1137,6 +1153,100 @@ def __call__(
         return hidden_states, encoder_hidden_states


+class AuraFlowAttnProcessor2_0:
+    """Attention processor used typically in processing Aura Flow."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+            raise ImportError(
+                "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        i=0,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        batch_size = hidden_states.shape[0]
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+        # Reshape.
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim)
+        key = key.view(batch_size, -1, attn.heads, head_dim)
+        value = value.view(batch_size, -1, attn.heads, head_dim)
+
+        # Apply QK norm.
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Concatenate the projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            )
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            )
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        # Attention.
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # Split the attention outputs.
+        if encoder_hidden_states is not None:
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+            )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        if encoder_hidden_states is not None:
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class XFormersAttnAddedKVProcessor:
     r"""
     Processor for implementing memory efficient attention using xFormers.
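For a rough end-to-end check of the new processor: when `encoder_hidden_states` is passed, the context tokens are projected with `add_q/k/v_proj`, concatenated in front of the sample tokens for a single joint `scaled_dot_product_attention` call, and the result is split back into the two streams. A hedged usage sketch (shapes and sizes are illustrative, not AuraFlow's real configuration):

```python
import torch

from diffusers.models.attention_processor import Attention, AuraFlowAttnProcessor2_0

attn = Attention(
    query_dim=3072,
    cross_attention_dim=None,
    heads=12,
    dim_head=256,
    qk_norm="fp32_layer_norm",
    added_kv_proj_dim=3072,
    added_proj_bias=False,
    context_pre_only=False,
    processor=AuraFlowAttnProcessor2_0(),
)

hidden_states = torch.randn(2, 1024, 3072)        # sample (e.g. latent patch) tokens
encoder_hidden_states = torch.randn(2, 77, 3072)  # context (e.g. text) tokens

sample_out, context_out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(sample_out.shape)   # torch.Size([2, 1024, 3072])
print(context_out.shape)  # torch.Size([2, 77, 3072])
```

Without `encoder_hidden_states`, the processor skips the context path entirely and returns a single tensor.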