Commits (54)
b547fcf  Fix QwenImage txt_seq_lens handling (kashif, Nov 23, 2025)
72a80c6  formatting (kashif, Nov 23, 2025)
88cee8b  formatting (kashif, Nov 23, 2025)
ac5ac24  remove txt_seq_lens and use bool mask (kashif, Nov 29, 2025)
0477526  Merge branch 'main' into txt_seq_lens (kashif, Nov 29, 2025)
18efdde  use compute_text_seq_len_from_mask (kashif, Nov 30, 2025)
6a549d4  add seq_lens to dispatch_attention_fn (kashif, Nov 30, 2025)
2d424e0  use joint_seq_lens (kashif, Nov 30, 2025)
30b5f98  remove unused index_block (kashif, Nov 30, 2025)
588dc04  Merge branch 'main' into txt_seq_lens (kashif, Dec 6, 2025)
f1c2d99  WIP: Remove seq_lens parameter and use mask-based approach (kashif, Dec 6, 2025)
ec52417  Merge branch 'txt_seq_lens' of https://github.com/kashif/diffusers in… (kashif, Dec 6, 2025)
beeb020  fix formatting (kashif, Dec 7, 2025)
5c6f8e3  undo sage changes (kashif, Dec 7, 2025)
5d434f6  xformers support (kashif, Dec 7, 2025)
71ba603  hub fix (kashif, Dec 8, 2025)
babf490  Merge branch 'main' into txt_seq_lens (kashif, Dec 8, 2025)
afad335  fix torch compile issues (kashif, Dec 8, 2025)
2d5ab16  Merge branch 'main' into txt_seq_lens (sayakpaul, Dec 9, 2025)
c78a1e9  fix tests (kashif, Dec 9, 2025)
d6d4b1d  use _prepare_attn_mask_native (kashif, Dec 9, 2025)
e999b76  proper deprecation notice (kashif, Dec 9, 2025)
8115f0b  add deprecate to txt_seq_lens (kashif, Dec 9, 2025)
3b1510c  Update src/diffusers/models/transformers/transformer_qwenimage.py (kashif, Dec 10, 2025)
3676d8e  Update src/diffusers/models/transformers/transformer_qwenimage.py (kashif, Dec 10, 2025)
9ed0ffd  Only create the mask if there's actual padding (kashif, Dec 10, 2025)
abec461  Merge branch 'main' into txt_seq_lens (kashif, Dec 10, 2025)
e26e7b3  fix order of docstrings (kashif, Dec 10, 2025)
59e3882  Adds performance benchmarks and optimization details for QwenImage (cdutr, Dec 11, 2025)
0cb2138  Merge branch 'main' into txt_seq_lens (kashif, Dec 12, 2025)
60bd454  rope_text_seq_len = text_seq_len (kashif, Dec 12, 2025)
a5abbb8  rename to max_txt_seq_len (kashif, Dec 12, 2025)
8415c57  Merge branch 'main' into txt_seq_lens (kashif, Dec 15, 2025)
afff5b7  Merge branch 'main' into txt_seq_lens (kashif, Dec 17, 2025)
8dc6c3f  Merge branch 'main' into txt_seq_lens (kashif, Dec 17, 2025)
22cb03d  removed deprecated args (kashif, Dec 17, 2025)
125a3a4  undo unrelated change (kashif, Dec 17, 2025)
b5b6342  Updates QwenImage performance documentation (cdutr, Dec 17, 2025)
61f5265  Updates deprecation warnings for txt_seq_lens parameter (cdutr, Dec 17, 2025)
2ef38e2  fix compile (kashif, Dec 17, 2025)
270c63f  Merge branch 'txt_seq_lens' of https://github.com/kashif/diffusers in… (kashif, Dec 17, 2025)
35efa06  formatting (kashif, Dec 17, 2025)
50c4815  fix compile tests (kashif, Dec 17, 2025)
c88bc06  Merge branch 'main' into txt_seq_lens (kashif, Dec 17, 2025)
1433783  rename helper (kashif, Dec 17, 2025)
8de799c  remove duplicate (kashif, Dec 17, 2025)
fc93747  smaller values (kashif, Dec 18, 2025)
8bb47d8  Merge branch 'main' into txt_seq_lens (kashif, Dec 19, 2025)
b7c288a  removed (kashif, Dec 20, 2025)
4700b7f  Merge branch 'main' into txt_seq_lens (kashif, Dec 20, 2025)
4fe7659  use torch.cond for torch compile (kashif, Dec 21, 2025)
77902bc  Construct joint attention mask once (kashif, Dec 21, 2025)
5b570c7  test different backends (kashif, Dec 21, 2025)
4d4e5f4  Merge branch 'main' into txt_seq_lens (kashif, Dec 24, 2025)
2 changes: 0 additions & 2 deletions examples/dreambooth/train_dreambooth_lora_qwen_image.py
@@ -1513,14 +1513,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             height=model_input.shape[3],
             width=model_input.shape[4],
         )
-        print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}")
         model_pred = transformer(
             hidden_states=packed_noisy_model_input,
             encoder_hidden_states=prompt_embeds,
             encoder_hidden_states_mask=prompt_embeds_mask,
             timestep=timesteps / 1000,
             img_shapes=img_shapes,
-            txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
             return_dict=False,
         )[0]
         model_pred = QwenImagePipeline._unpack_latents(
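The removed `print` was debug output, and the removed `txt_seq_lens` kwarg duplicated information the mask already encodes. A minimal sketch of the equivalence (tensor values are illustrative):

```python
import torch

# Sample 0 has 3 valid text tokens, sample 1 has 5.
prompt_embeds_mask = torch.tensor([[1, 1, 1, 0, 0],
                                   [1, 1, 1, 1, 1]])

# Exactly what the deleted kwarg used to pass explicitly; the transformer
# now derives per-sample lengths from encoder_hidden_states_mask internally.
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()  # [3, 5]
```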
70 changes: 52 additions & 18 deletions src/diffusers/models/attention_dispatch.py
@@ -305,6 +305,7 @@ def dispatch_attention_fn(
     *,
     backend: Optional[AttentionBackendName] = None,
     parallel_config: Optional["ParallelConfig"] = None,
+    seq_lens: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     attention_kwargs = attention_kwargs or {}
 
@@ -327,6 +328,8 @@
         **attention_kwargs,
         "_parallel_config": parallel_config,
     }
+    if seq_lens is not None:
+        kwargs["seq_lens"] = seq_lens
     if is_torch_version(">=", "2.5.0"):
         kwargs["enable_gqa"] = enable_gqa

@@ -1400,18 +1403,29 @@ def _flash_varlen_attention(
     is_causal: bool = False,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
+    seq_lens: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     batch_size, seq_len_q, _, _ = query.shape
     _, seq_len_kv, _, _ = key.shape
 
-    if attn_mask is not None:
-        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+    if seq_lens is not None:
+        seq_lens = seq_lens.to(query.device)
+        # use the same lengths for Q and KV
+        seqlens_k = seq_lens
+        cu_seqlens_q = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)], dim=0).to(torch.int32)
+        cu_seqlens_k = cu_seqlens_q
+        max_seqlen_q = int(seq_lens.max().item())
+        max_seqlen_k = max_seqlen_q
+        attn_mask = None  # varlen uses lengths
+    else:
+        if attn_mask is not None:
+            attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
 
-    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-        _prepare_for_flash_attn_or_sage_varlen(
-            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-        )
-    )
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen(
+                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+            )
+        )
 
     key_valid, value_valid = [], []
     for b in range(batch_size):
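The `cu_seqlens` construction above packs per-sample lengths into the cumulative-offset format that flash-attention varlen kernels expect: entry i and i+1 bracket sample i's token range in the batch-flattened sequence. A small self-contained illustration:

```python
import torch

seq_lens = torch.tensor([5, 3, 7])

# Prepend 0 and take the running sum of lengths.
cu_seqlens = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)], dim=0).to(torch.int32)
print(cu_seqlens)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)

# The max length is passed separately so the kernel can size its workspace.
max_seqlen = int(seq_lens.max().item())  # 7
```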
@@ -1521,18 +1535,28 @@ def _flash_varlen_attention_3(
     is_causal: bool = False,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
+    seq_lens: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     batch_size, seq_len_q, _, _ = query.shape
     _, seq_len_kv, _, _ = key.shape
 
-    if attn_mask is not None:
-        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+    if seq_lens is not None:
+        seq_lens = seq_lens.to(query.device)
+        seqlens_k = seq_lens
+        cu_seqlens_q = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)], dim=0).to(torch.int32)
+        cu_seqlens_k = cu_seqlens_q
+        max_seqlen_q = int(seq_lens.max().item())
+        max_seqlen_k = max_seqlen_q
+        attn_mask = None  # varlen uses lengths
+    else:
+        if attn_mask is not None:
+            attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
 
-    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-        _prepare_for_flash_attn_or_sage_varlen(
-            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-        )
-    )
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen(
+                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+            )
+        )
 
     key_valid, value_valid = [], []
     for b in range(batch_size):
@@ -2023,21 +2047,31 @@ def _sage_varlen_attention(
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
+    seq_lens: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     if return_lse:
         raise ValueError("Sage varlen backend does not support setting `return_lse=True`.")
 
     batch_size, seq_len_q, _, _ = query.shape
     _, seq_len_kv, _, _ = key.shape
 
-    if attn_mask is not None:
-        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+    if seq_lens is not None:
+        seq_lens = seq_lens.to(query.device)
+        seqlens_k = seq_lens
+        cu_seqlens_q = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)], dim=0).to(torch.int32)
+        cu_seqlens_k = cu_seqlens_q
+        max_seqlen_q = int(seq_lens.max().item())
+        max_seqlen_k = max_seqlen_q
+        attn_mask = None  # varlen uses lengths
+    else:
+        if attn_mask is not None:
+            attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
 
-    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-        _prepare_for_flash_attn_or_sage_varlen(
-            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-        )
-    )
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen(
+                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+            )
+        )
 
     key_valid, value_valid = [], []
     for b in range(batch_size):
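All three varlen paths then gather only the valid tokens from `key`/`value` before handing one packed sequence to the kernel. A hedged sketch of that unpadding step, following the `key_valid, value_valid` loop visible above (`pack_valid_tokens` is a name introduced here for illustration):

```python
import torch

def pack_valid_tokens(key: torch.Tensor, seqlens_k: torch.Tensor) -> torch.Tensor:
    """Drop padded positions: (batch, seq, heads, dim) -> (total_valid, heads, dim)."""
    batch_size = key.shape[0]
    key_valid = [key[b, : int(seqlens_k[b])] for b in range(batch_size)]
    return torch.cat(key_valid, dim=0)

key = torch.randn(2, 8, 4, 64)
seqlens_k = torch.tensor([8, 5])
packed = pack_valid_tokens(key, seqlens_k)
print(packed.shape)  # torch.Size([13, 4, 64])
```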
36 changes: 22 additions & 14 deletions src/diffusers/models/controlnets/controlnet_qwenimage.py
@@ -31,6 +31,7 @@
     QwenImageTransformerBlock,
     QwenTimestepProjEmbeddings,
     RMSNorm,
+    compute_text_seq_len_from_mask,
 )
@@ -189,12 +190,11 @@ def forward(
         encoder_hidden_states_mask: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         img_shapes: Optional[List[Tuple[int, int, int]]] = None,
-        txt_seq_lens: Optional[List[int]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
-        The [`FluxTransformer2DModel`] forward method.
+        The [`QwenImageControlNetModel`] forward method.
 
         Args:
             hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
@@ -205,28 +205,30 @@
                 The scale factor for ControlNet outputs.
             encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
-                from the embeddings of input conditions.
+            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
+                Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
+                Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
+                (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
             timestep ( `torch.LongTensor`):
                 Used to indicate denoising step.
             block_controlnet_hidden_states: (`list` of `torch.Tensor`):
                 A list of tensors that if specified are added to the residuals of transformer blocks.
             img_shapes (`List[Tuple[int, int, int]]`, *optional*):
                 Image shapes for RoPE computation.
             joint_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
-                tuple.
+                Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
 
         Returns:
-            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
-            `tuple` where the first element is the sample tensor.
+            If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where
+            the first element is the controlnet block samples.
         """
         if joint_attention_kwargs is not None:
             joint_attention_kwargs = joint_attention_kwargs.copy()
             lora_scale = joint_attention_kwargs.pop("scale", 1.0)
         else:
+            joint_attention_kwargs = {}
             lora_scale = 1.0
 
         if USE_PEFT_BACKEND:
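The docstring's claim that the mask need not be contiguous is worth pinning down with a tiny example (values illustrative):

```python
import torch

# Valid tokens need not form a prefix: position 1 is masked out while
# positions 0, 2, and 3 attend normally, because the mask is applied
# element-wise inside attention rather than collapsed to a single length.
encoder_hidden_states_mask = torch.tensor([[1.0, 0.0, 1.0, 1.0, 0.0]])
```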
@@ -244,14 +246,22 @@
 
         temb = self.time_text_embed(timestep, hidden_states)
 
-        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+        # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
+        text_seq_len, text_seq_lens_per_sample, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
+            encoder_hidden_states, encoder_hidden_states_mask
+        )
+
+        if text_seq_lens_per_sample is not None:
+            joint_attention_kwargs.setdefault("text_seq_lens", text_seq_lens_per_sample)
+
+        image_rotary_emb = self.pos_embed(img_shapes, text_seq_len, device=hidden_states.device)
 
         timestep = timestep.to(hidden_states.dtype)
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
         encoder_hidden_states = self.txt_in(encoder_hidden_states)
 
         block_samples = ()
-        for index_block, block in enumerate(self.transformer_blocks):
+        for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                     block,
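The helper's body lives in transformer_qwenimage.py and is not part of this diff; the following is a minimal sketch of the semantics implied by its call site here, assuming the three-tuple return shape shown above (the real implementation may differ, for instance around torch.compile handling):

```python
import torch
from typing import Optional, Tuple

def compute_text_seq_len_from_mask(
    encoder_hidden_states: torch.Tensor,
    encoder_hidden_states_mask: Optional[torch.Tensor],
) -> Tuple[int, Optional[torch.Tensor], Optional[torch.Tensor]]:
    # RoPE always covers the full padded text length.
    text_seq_len = encoder_hidden_states.shape[1]
    if encoder_hidden_states_mask is None:
        return text_seq_len, None, None
    mask = encoder_hidden_states_mask.bool()
    per_sample_lens = mask.sum(dim=1)
    # Only keep the mask and per-sample lengths when padding actually exists,
    # in the spirit of the "Only create the mask if there's actual padding" commit.
    if bool((per_sample_lens == text_seq_len).all()):
        return text_seq_len, None, None
    return text_seq_len, per_sample_lens, mask
```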
@@ -321,7 +331,6 @@ def forward(
         encoder_hidden_states_mask: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         img_shapes: Optional[List[Tuple[int, int, int]]] = None,
-        txt_seq_lens: Optional[List[int]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[QwenImageControlNetOutput, Tuple]:
@@ -339,7 +348,6 @@
             encoder_hidden_states_mask=encoder_hidden_states_mask,
             timestep=timestep,
             img_shapes=img_shapes,
-            txt_seq_lens=txt_seq_lens,
             joint_attention_kwargs=joint_attention_kwargs,
             return_dict=return_dict,
         )