
Commit 7329ead

mid_block_only_cross_attention better default
allow `mid_block_only_cross_attention` to default to `only_cross_attention` when `only_cross_attention` is given as a single boolean
1 parent 320d288 commit 7329ead

2 files changed (+26, -4 lines)
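The defaulting rule that both diffs below introduce can be summarized in plain Python. The helper here is only an illustrative sketch of that rule, not code from the commit; the function name `resolve_mid_block_only_cross_attention` is made up for this example.

from typing import List, Optional, Tuple, Union


def resolve_mid_block_only_cross_attention(
    only_cross_attention: Union[bool, Tuple[bool, ...]],
    mid_block_only_cross_attention: Optional[bool],
    num_down_blocks: int,
) -> Tuple[List[bool], bool]:
    # Mirrors the logic added to __init__: a single boolean is broadcast to
    # every down block and, when the mid-block flag was left unset (None),
    # reused for the mid block as well.
    if isinstance(only_cross_attention, bool):
        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = only_cross_attention
        only_cross_attention = [only_cross_attention] * num_down_blocks

    # When a per-block sequence was passed instead, an unset mid-block flag
    # still falls back to False, matching the previous default.
    if mid_block_only_cross_attention is None:
        mid_block_only_cross_attention = False

    return list(only_cross_attention), mid_block_only_cross_attention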

src/diffusers/models/unet_2d_condition.py

Lines changed: 13 additions & 2 deletions

@@ -110,7 +110,12 @@ class conditioning with `class_embed_type` equal to `None`.
         projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
             using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
         class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
-            embeddings with the class embeddings.
+            embeddings with the class embeddings.
+        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
+            `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will
+            default to `False`.
     """

     _supports_gradient_checkpointing = True
@@ -158,7 +163,7 @@ def __init__(
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
         class_embeddings_concat: bool = False,
-        mid_block_only_cross_attention: bool = False,
+        mid_block_only_cross_attention: Optional[bool] = None,
     ):
         super().__init__()

@@ -266,8 +271,14 @@ def __init__(
         self.up_blocks = nn.ModuleList([])

         if isinstance(only_cross_attention, bool):
+            if mid_block_only_cross_attention is None:
+                mid_block_only_cross_attention = only_cross_attention
+
             only_cross_attention = [only_cross_attention] * len(down_block_types)

+        if mid_block_only_cross_attention is None:
+            mid_block_only_cross_attention = False
+
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 13 additions & 2 deletions

@@ -191,7 +191,12 @@ class conditioning with `class_embed_type` equal to `None`.
         projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
             using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
         class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
-            embeddings with the class embeddings.
+            embeddings with the class embeddings.
+        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+            Whether to use cross attention with the mid block when using the `UNetMidBlockFlatSimpleCrossAttn`. If
+            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
+            `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will
+            default to `False`.
     """

     _supports_gradient_checkpointing = True
@@ -244,7 +249,7 @@ def __init__(
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
         class_embeddings_concat: bool = False,
-        mid_block_only_cross_attention: bool = False,
+        mid_block_only_cross_attention: Optional[bool] = None,
     ):
         super().__init__()

@@ -358,8 +363,14 @@ def __init__(
         self.up_blocks = nn.ModuleList([])

         if isinstance(only_cross_attention, bool):
+            if mid_block_only_cross_attention is None:
+                mid_block_only_cross_attention = only_cross_attention
+
             only_cross_attention = [only_cross_attention] * len(down_block_types)

+        if mid_block_only_cross_attention is None:
+            mid_block_only_cross_attention = False
+
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)

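With the illustrative helper sketched above the diffs, the three relevant cases resolve as follows (again, a sketch rather than code from the commit):

# A single boolean: the mid block now inherits the same setting instead of False.
assert resolve_mid_block_only_cross_attention(True, None, 4) == ([True] * 4, True)

# A per-block sequence: an unset mid-block flag still falls back to False.
assert resolve_mid_block_only_cross_attention((True, False, False), None, 3) == ([True, False, False], False)

# An explicitly passed value always wins over the inherited default.
assert resolve_mid_block_only_cross_attention(True, False, 4) == ([True] * 4, False)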