Commits (63)
b547fcf
Fix QwenImage txt_seq_lens handling
kashif Nov 23, 2025
72a80c6
formatting
kashif Nov 23, 2025
88cee8b
formatting
kashif Nov 23, 2025
ac5ac24
remove txt_seq_lens and use bool mask
kashif Nov 29, 2025
0477526
Merge branch 'main' into txt_seq_lens
kashif Nov 29, 2025
18efdde
use compute_text_seq_len_from_mask
kashif Nov 30, 2025
6a549d4
add seq_lens to dispatch_attention_fn
kashif Nov 30, 2025
2d424e0
use joint_seq_lens
kashif Nov 30, 2025
30b5f98
remove unused index_block
kashif Nov 30, 2025
588dc04
Merge branch 'main' into txt_seq_lens
kashif Dec 6, 2025
f1c2d99
WIP: Remove seq_lens parameter and use mask-based approach
kashif Dec 6, 2025
ec52417
Merge branch 'txt_seq_lens' of https://github.com/kashif/diffusers in…
kashif Dec 6, 2025
beeb020
fix formatting
kashif Dec 7, 2025
5c6f8e3
undo sage changes
kashif Dec 7, 2025
5d434f6
xformers support
kashif Dec 7, 2025
71ba603
hub fix
kashif Dec 8, 2025
babf490
Merge branch 'main' into txt_seq_lens
kashif Dec 8, 2025
afad335
fix torch compile issues
kashif Dec 8, 2025
2d5ab16
Merge branch 'main' into txt_seq_lens
sayakpaul Dec 9, 2025
c78a1e9
fix tests
kashif Dec 9, 2025
d6d4b1d
use _prepare_attn_mask_native
kashif Dec 9, 2025
e999b76
proper deprecation notice
kashif Dec 9, 2025
8115f0b
add deprecate to txt_seq_lens
kashif Dec 9, 2025
3b1510c
Update src/diffusers/models/transformers/transformer_qwenimage.py
kashif Dec 10, 2025
3676d8e
Update src/diffusers/models/transformers/transformer_qwenimage.py
kashif Dec 10, 2025
9ed0ffd
Only create the mask if there's actual padding
kashif Dec 10, 2025
abec461
Merge branch 'main' into txt_seq_lens
kashif Dec 10, 2025
e26e7b3
fix order of docstrings
kashif Dec 10, 2025
59e3882
Adds performance benchmarks and optimization details for QwenImage
cdutr Dec 11, 2025
0cb2138
Merge branch 'main' into txt_seq_lens
kashif Dec 12, 2025
60bd454
rope_text_seq_len = text_seq_len
kashif Dec 12, 2025
a5abbb8
rename to max_txt_seq_len
kashif Dec 12, 2025
8415c57
Merge branch 'main' into txt_seq_lens
kashif Dec 15, 2025
afff5b7
Merge branch 'main' into txt_seq_lens
kashif Dec 17, 2025
8dc6c3f
Merge branch 'main' into txt_seq_lens
kashif Dec 17, 2025
22cb03d
removed deprecated args
kashif Dec 17, 2025
125a3a4
undo unrelated change
kashif Dec 17, 2025
b5b6342
Updates QwenImage performance documentation
cdutr Dec 17, 2025
61f5265
Updates deprecation warnings for txt_seq_lens parameter
cdutr Dec 17, 2025
2ef38e2
fix compile
kashif Dec 17, 2025
270c63f
Merge branch 'txt_seq_lens' of https://github.com/kashif/diffusers in…
kashif Dec 17, 2025
35efa06
formatting
kashif Dec 17, 2025
50c4815
fix compile tests
kashif Dec 17, 2025
c88bc06
Merge branch 'main' into txt_seq_lens
kashif Dec 17, 2025
1433783
rename helper
kashif Dec 17, 2025
8de799c
remove duplicate
kashif Dec 17, 2025
fc93747
smaller values
kashif Dec 18, 2025
8bb47d8
Merge branch 'main' into txt_seq_lens
kashif Dec 19, 2025
b7c288a
removed
kashif Dec 20, 2025
4700b7f
Merge branch 'main' into txt_seq_lens
kashif Dec 20, 2025
2f86879
split attention
dxqb Dec 21, 2025
87bbde4
fix type hints
dxqb Dec 21, 2025
66056f1
fix error if no attn mask is passed
dxqb Dec 21, 2025
0a713d1
Merge branch 'main' into split_attention
dxqb Dec 23, 2025
b9880f6
Merge remote-tracking branch 'origin/main' into pr-12702-base
dxqb Dec 23, 2025
a8bba06
Merge branch 'pr-12702-base' into split_attention
dxqb Dec 23, 2025
5eef3ef
check attention mask
dxqb Dec 26, 2025
e593603
Merge branch 'check_attn_mask' into split_attention
dxqb Dec 26, 2025
23e7a65
Merge branch 'main' into pr-12702-base
dxqb Dec 26, 2025
0584542
Merge branch 'pr-12702-base' into split_attention
dxqb Dec 26, 2025
7651363
more backends
dxqb Dec 26, 2025
cc134a7
bugfix
dxqb Dec 26, 2025
7e456cd
bugfix
dxqb Dec 26, 2025
38 changes: 36 additions & 2 deletions docs/source/en/api/pipelines/qwenimage.md
@@ -108,12 +108,46 @@ pipe = QwenImageEditPlusPipeline.from_pretrained(
image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg")
image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png")
image = pipe(
image=[image_1, image_2],
prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
num_inference_steps=50
).images[0]
```

## Performance

### torch.compile

Using `torch.compile` on the transformer provides a ~2.4x speedup (A100 80GB: 4.70s → 1.93s):

```python
import torch
from diffusers import QwenImagePipeline

pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer = torch.compile(pipe.transformer)

# First call triggers compilation (~7s overhead)
# Subsequent calls run ~2.4x faster
image = pipe("a cat", num_inference_steps=50).images[0]
```
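Changing the prompt length or output resolution between calls can trigger recompilation. If a longer first-call compile time is acceptable, a more aggressive variant is possible; this is a sketch using standard `torch.compile` options, not benchmarked here:

```python
# Optional: heavier autotuning; the first call compiles for noticeably longer.
# Drop `fullgraph=True` if the transformer's graph cannot be captured whole.
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
```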

### Batched Inference with Variable-Length Prompts

When classifier-free guidance (CFG) batches a prompt and a negative prompt of different token lengths, the shorter sequence is padded to match the longer one, and an attention mask ensures the padding tokens do not influence the generated output.

```python
# CFG with different prompt lengths works correctly
image = pipe(
prompt="A cat",
negative_prompt="blurry, low quality, distorted",
true_cfg_scale=3.5,
num_inference_steps=50,
).images[0]
```
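Under the hood, the shorter prompt's embedding is padded and a boolean mask is forwarded to attention so padded keys are ignored. Below is a minimal, self-contained sketch of this kind of key-padding masking; it is illustrative only, not the pipeline's exact implementation:

```python
import torch
import torch.nn.functional as F

# Two prompts of different token lengths, padded to a common length.
seq_lens = torch.tensor([3, 7])          # true length of each prompt
max_len = int(seq_lens.max())
# Boolean mask: True for real tokens, False for padding.
mask = torch.arange(max_len)[None, :] < seq_lens[:, None]   # (batch, max_len)

q = torch.randn(2, 8, max_len, 64)       # (batch, heads, seq, head_dim)
k = torch.randn(2, 8, max_len, 64)
v = torch.randn(2, 8, max_len, 64)

# Broadcast the key-padding mask to (batch, 1, 1, seq): padded keys get
# -inf attention scores and contribute nothing after the softmax.
attn_mask = mask[:, None, None, :]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```

Because the mask already carries the per-prompt lengths, an explicit `txt_seq_lens` argument is no longer needed.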

For detailed benchmark scripts and results, see [this gist](https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f).

## QwenImagePipeline

[[autodoc]] QwenImagePipeline
2 changes: 0 additions & 2 deletions examples/dreambooth/train_dreambooth_lora_qwen_image.py
@@ -1513,14 +1513,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
height=model_input.shape[3],
width=model_input.shape[4],
)
-print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}")
model_pred = transformer(
hidden_states=packed_noisy_model_input,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
timestep=timesteps / 1000,
img_shapes=img_shapes,
-txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
return_dict=False,
)[0]
model_pred = QwenImagePipeline._unpack_latents(
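The two removed lines above are the debug print and the now-redundant `txt_seq_lens` argument. A sketch of the updated call (the names mirror the training script above; the transformer now derives per-prompt text lengths from `encoder_hidden_states_mask`):

```python
# Before (removed in this PR): per-prompt text lengths were passed explicitly.
# model_pred = transformer(..., txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), ...)

# After: the mask alone carries the sequence-length information.
model_pred = transformer(
    hidden_states=packed_noisy_model_input,
    encoder_hidden_states=prompt_embeds,
    encoder_hidden_states_mask=prompt_embeds_mask,
    timestep=timesteps / 1000,
    img_shapes=img_shapes,
    return_dict=False,
)[0]
```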