@@ -191,7 +191,12 @@ class conditioning with `class_embed_type` equal to `None`.
         projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
             using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
         class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
-            embeddings with the class embeddings.
+            embeddings with the class embeddings.
+        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+            Whether to use cross attention with the mid block when using `UNetMidBlockFlatSimpleCrossAttn`. If
+            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Otherwise, it
+            defaults to `False`.
     """

     _supports_gradient_checkpointing = True
@@ -244,7 +249,7 @@ def __init__(
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
         class_embeddings_concat: bool = False,
-        mid_block_only_cross_attention: bool = False,
+        mid_block_only_cross_attention: Optional[bool] = None,
     ):
         super().__init__()

@@ -358,8 +363,14 @@ def __init__(
         self.up_blocks = nn.ModuleList([])

         if isinstance(only_cross_attention, bool):
+            if mid_block_only_cross_attention is None:
+                mid_block_only_cross_attention = only_cross_attention
+
             only_cross_attention = [only_cross_attention] * len(down_block_types)

+        if mid_block_only_cross_attention is None:
+            mid_block_only_cross_attention = False
+
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)
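For reference, a minimal standalone sketch of the default-resolution order this patch implements: a single boolean `only_cross_attention` doubles as the mid-block default, and a per-block list leaves the mid block at `False` unless set explicitly. The helper name `resolve_mid_block_only_cross_attention` is hypothetical, not part of the diff.

from typing import List, Optional, Union


def resolve_mid_block_only_cross_attention(
    only_cross_attention: Union[bool, List[bool]],
    mid_block_only_cross_attention: Optional[bool],
) -> bool:
    # A single bool for `only_cross_attention` also serves as the
    # mid-block default when no explicit value was given.
    if isinstance(only_cross_attention, bool) and mid_block_only_cross_attention is None:
        return only_cross_attention
    # With a per-block list and no explicit value, the mid block defaults to False.
    if mid_block_only_cross_attention is None:
        return False
    # An explicit value always wins.
    return mid_block_only_cross_attention


assert resolve_mid_block_only_cross_attention(True, None) is True
assert resolve_mid_block_only_cross_attention([True, False], None) is False
assert resolve_mid_block_only_cross_attention([True, False], True) is True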