Skip to content

Commit 3c8b67b

Browse files
authored
Flux: pass joint_attention_kwargs when using gradient_checkpointing (#11814)
Flux: pass joint_attention_kwargs when gradient_checkpointing
1 parent 9feb946 commit 3c8b67b

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,7 @@ def forward(
490490
encoder_hidden_states,
491491
temb,
492492
image_rotary_emb,
493+
joint_attention_kwargs,
493494
)
494495

495496
else:
@@ -521,6 +522,7 @@ def forward(
521522
encoder_hidden_states,
522523
temb,
523524
image_rotary_emb,
525+
joint_attention_kwargs,
524526
)
525527

526528
else:

0 commit comments

Comments
 (0)