Skip to content

Commit 7730471

Browse files
yuanheng-zhao and wtomin
authored and committed
[Bugfix] Enable teacache in QwenImageEditPlusPipeline (vllm-project#379)
Signed-off-by: yuanheng <jonathan.zhaoyh@gmail.com>
Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com>
1 parent e6f9189 commit 7730471

File tree

3 files changed

+33
-22
lines changed

3 files changed

+33
-22
lines changed

docs/user_guide/acceleration/teacache.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ th {
108108
|--------------|--------|-------------------|
109109
| `QwenImagePipeline` | Qwen-Image | `Qwen/Qwen-Image` |
110110
| `QwenImageEditPipeline` | Qwen-Image-Edit | `Qwen/Qwen-Image-Edit` |
111+
| `QwenImageEditPlusPipeline` | Qwen-Image-Edit | `Qwen/Qwen-Image-Edit-2509` |
111112

112113
### Coming Soon
113114

vllm_omni/diffusion/cache/teacache/extractors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ def postprocess(h):
262262
EXTRACTOR_REGISTRY: dict[str, Callable] = {
263263
"QwenImagePipeline": extract_qwen_context,
264264
"QwenImageEditPipeline": extract_qwen_context,
265+
"QwenImageEditPlusPipeline": extract_qwen_context,
265266
# Future models:
266267
# "FluxPipeline": extract_flux_context,
267268
# "CogVideoXPipeline": extract_cogvideox_context,

vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def __init__(
194194
)
195195

196196
self.stage = None
197+
self._cache_backend = None
197198

198199
self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
199200
self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
@@ -543,31 +544,39 @@ def diffuse(
543544
if image_latents is not None:
544545
latent_model_input = torch.cat([latents, image_latents], dim=1)
545546

546-
noise_pred = self.transformer(
547-
hidden_states=latent_model_input,
548-
timestep=timestep / 1000,
549-
guidance=guidance,
550-
encoder_hidden_states_mask=prompt_embeds_mask,
551-
encoder_hidden_states=prompt_embeds,
552-
img_shapes=img_shapes,
553-
txt_seq_lens=txt_seq_lens,
554-
attention_kwargs=self.attention_kwargs,
555-
return_dict=False,
556-
)[0]
547+
transformer_kwargs = {
548+
"hidden_states": latent_model_input,
549+
"timestep": timestep / 1000,
550+
"guidance": guidance,
551+
"encoder_hidden_states_mask": prompt_embeds_mask,
552+
"encoder_hidden_states": prompt_embeds,
553+
"img_shapes": img_shapes,
554+
"txt_seq_lens": txt_seq_lens,
555+
"attention_kwargs": self.attention_kwargs,
556+
"return_dict": False,
557+
}
558+
if self._cache_backend is not None:
559+
transformer_kwargs["cache_branch"] = "positive"
560+
561+
noise_pred = self.transformer(**transformer_kwargs)[0]
557562
noise_pred = noise_pred[:, : latents.size(1)]
558563

559564
if do_true_cfg:
560-
neg_noise_pred = self.transformer(
561-
hidden_states=latent_model_input,
562-
timestep=timestep / 1000,
563-
guidance=guidance,
564-
encoder_hidden_states_mask=negative_prompt_embeds_mask,
565-
encoder_hidden_states=negative_prompt_embeds,
566-
img_shapes=img_shapes,
567-
txt_seq_lens=negative_txt_seq_lens,
568-
attention_kwargs=self.attention_kwargs,
569-
return_dict=False,
570-
)[0]
565+
neg_transformer_kwargs = {
566+
"hidden_states": latent_model_input,
567+
"timestep": timestep / 1000,
568+
"guidance": guidance,
569+
"encoder_hidden_states_mask": negative_prompt_embeds_mask,
570+
"encoder_hidden_states": negative_prompt_embeds,
571+
"img_shapes": img_shapes,
572+
"txt_seq_lens": negative_txt_seq_lens,
573+
"attention_kwargs": self.attention_kwargs,
574+
"return_dict": False,
575+
}
576+
if self._cache_backend is not None:
577+
neg_transformer_kwargs["cache_branch"] = "negative"
578+
579+
neg_noise_pred = self.transformer(**neg_transformer_kwargs)[0]
571580
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
572581
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
573582
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)

0 commit comments

Comments
 (0)