@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import warnings
 from typing import Callable, Optional, Union

 import torch
@@ -166,7 +165,8 @@ def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ):
         is_lora = hasattr(self, "processor") and isinstance(
-            self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor)
+            self.processor,
+            (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
         )
         is_custom_diffusion = hasattr(self, "processor") and isinstance(
             self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
@@ -200,14 +200,6 @@ def set_use_memory_efficient_attention_xformers(
                     "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                     " only available for GPU "
                 )
-            elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
-                warnings.warn(
-                    "You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
-                    "We will default to PyTorch's native efficient flash attention implementation (`F.scaled_dot_product_attention`) "
-                    "introduced in PyTorch 2.0. In case you are using LoRA or Custom Diffusion, we will fall "
-                    "back to their respective attention processors i.e., we will NOT use the PyTorch 2.0 "
-                    "native efficient flash attention."
-                )
             else:
                 try:
                     # Make sure we can run the memory efficient attention
@@ -220,6 +212,8 @@ def set_use_memory_efficient_attention_xformers(
                     raise e

             if is_lora:
+                # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
+                # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
                 processor = LoRAXFormersAttnProcessor(
                     hidden_size=self.processor.hidden_size,
                     cross_attention_dim=self.processor.cross_attention_dim,
@@ -252,7 +246,10 @@ def set_use_memory_efficient_attention_xformers(
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
             if is_lora:
-                processor = LoRAAttnProcessor(
+                attn_processor_class = (
+                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+                )
+                processor = attn_processor_class(
                     hidden_size=self.processor.hidden_size,
                     cross_attention_dim=self.processor.cross_attention_dim,
                     rank=self.processor.rank,
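
For context, a minimal sketch of the dispatch that the hunk above introduces. The import path is the module this diff edits (`diffusers.models.attention_processor`); everything else is illustrative:

```python
# Minimal sketch of the new fallback: on PyTorch >= 2.0 the SDPA-based LoRA
# processor is picked, otherwise the original LoRAAttnProcessor is kept.
import torch.nn.functional as F

from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0

attn_processor_class = (
    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
print(attn_processor_class.__name__)  # "LoRAAttnProcessor2_0" on torch >= 2.0
```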
@@ -548,6 +545,8 @@ class LoRAAttnProcessor(nn.Module):
             The number of channels in the `encoder_hidden_states`.
         rank (`int`, defaults to 4):
             The dimension of the LoRA update matrices.
+        network_alpha (`int`, *optional*):
+            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
     """

     def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
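
A hedged sketch of what the Kohya-style `network_alpha` in the docstring above refers to. The actual scaling lives in `LoRALinearLayer` (not part of this diff); this standalone snippet only assumes the usual `alpha / rank` convention, and none of its layer names come from the source:

```python
# Standalone illustration of the alpha/rank convention: the low-rank update is
# scaled by network_alpha / rank before being added to the frozen projection.
import torch
import torch.nn as nn

hidden_size, rank, network_alpha = 320, 4, 1.0
down = nn.Linear(hidden_size, rank, bias=False)
up = nn.Linear(rank, hidden_size, bias=False)

def lora_delta(x: torch.Tensor) -> torch.Tensor:
    out = up(down(x))
    if network_alpha is not None:
        out = out * (network_alpha / rank)  # Kohya (A1111) style scaling
    return out

x = torch.randn(2, 77, hidden_size)
print(lora_delta(x).shape)  # torch.Size([2, 77, 320])
```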
@@ -843,6 +842,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
             The number of channels in the `encoder_hidden_states`.
         rank (`int`, defaults to 4):
             The dimension of the LoRA update matrices.
+
     """

     def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
@@ -1162,6 +1162,9 @@ class LoRAXFormersAttnProcessor(nn.Module):
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
+        network_alpha (`int`, *optional*):
+            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
+
    """

    def __init__(
@@ -1236,6 +1239,97 @@ def __call__(
         return hidden_states


+class LoRAAttnProcessor2_0(nn.Module):
+    r"""
+    Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
+    attention.
+
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the `encoder_hidden_states`.
+        rank (`int`, defaults to 4):
+            The dimension of the LoRA update matrices.
+        network_alpha (`int`, *optional*):
+            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
+        super().__init__()
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        inner_dim = hidden_states.shape[-1]
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
+
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class CustomDiffusionXFormersAttnProcessor(nn.Module):
     r"""
     Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
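
A usage sketch for the class added above (requires PyTorch >= 2.0). The `Attention` constructor arguments are illustrative assumptions; only the class names come from this file:

```python
# Attach the new processor to a bare Attention block and run a forward pass.
# hidden_size must match the block's inner dimension (heads * dim_head).
import torch

from diffusers.models.attention_processor import Attention, LoRAAttnProcessor2_0

attn = Attention(query_dim=320, heads=8, dim_head=40)  # inner_dim = 320
attn.set_processor(LoRAAttnProcessor2_0(hidden_size=320, cross_attention_dim=None, rank=4))

hidden_states = torch.randn(2, 64, 320)
# q/k/v/out projections get LoRA deltas; attention itself runs via F.scaled_dot_product_attention
out = attn(hidden_states)
print(out.shape)  # torch.Size([2, 64, 320])
```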
@@ -1520,6 +1614,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
     XFormersAttnAddedKVProcessor,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
+    LoRAAttnProcessor2_0,
     LoRAAttnAddedKVProcessor,
     CustomDiffusionAttnProcessor,
     CustomDiffusionXFormersAttnProcessor,
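
The final hunk registers the new processor in the `Union` type alias at the bottom of the file (assumed here to be the exported `AttentionProcessor` alias that processor-dispatching code checks against). A small sketch of what that enables:

```python
# LoRAAttnProcessor2_0 instances now count as valid AttentionProcessor members.
from typing import get_args

from diffusers.models.attention_processor import AttentionProcessor, LoRAAttnProcessor2_0

processor = LoRAAttnProcessor2_0(hidden_size=320, rank=4)  # raises ImportError on torch < 2.0
assert isinstance(processor, get_args(AttentionProcessor))
```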