
[Fix] Allow prompt and prior_token_ids to be provided simultaneously in GlmImagePipeline #13092

Open

JaredforReal wants to merge 2 commits into huggingface:main from JaredforReal:fix/image-input

Conversation


JaredforReal (Contributor) commented Feb 7, 2026

What does this PR do?

Previously, `GlmImagePipeline.check_inputs()` enforced mutual exclusion between `prompt` and `prior_token_ids`, raising a `ValueError` if both were provided. This was unnecessarily restrictive: there is a valid use case where a user wants to supply precomputed `prior_token_ids` (to skip the expensive AR generation step) while still passing `prompt` so that `prompt_embeds` (glyph embeddings) can be derived from it automatically.

This PR relaxes that constraint:

  • Removed the check that raised a `ValueError` when both `prompt` and `prior_token_ids` were provided.
  • Relaxed the "`prior_token_ids` requires `prompt_embeds`" check: now either `prompt` or `prompt_embeds` satisfies the requirement, since `prompt` will be used to generate `prompt_embeds` via `encode_prompt()` downstream. A condensed sketch of the resulting validation follows.
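For illustration, here is a hypothetical, condensed sketch of the validation after this change. The real `check_inputs()` in `pipeline_glm_image.py` validates more arguments (height, width, etc.), so treat this as a sketch rather than the actual implementation:

```python
# Hypothetical, condensed sketch of the relaxed validation; the actual
# check_inputs() method validates additional arguments beyond these three.
def check_inputs(prompt, prompt_embeds, prior_token_ids):
    if prompt is not None and prompt_embeds is not None:
        # Still mutually exclusive: the pipeline cannot know which one to use.
        raise ValueError(
            "Cannot forward both `prompt` and `prompt_embeds`. Please make sure to"
            " only forward one of the two."
        )
    if prompt is None and prompt_embeds is None:
        raise ValueError("Provide either `prompt` or `prompt_embeds`.")
    # Relaxed: `prior_token_ids` no longer conflicts with `prompt`. It only needs
    # some source of glyph embeddings: either precomputed `prompt_embeds`, or a
    # `prompt` that encode_prompt() will turn into `prompt_embeds` downstream.
    if prior_token_ids is not None and prompt_embeds is None and prompt is None:
        raise ValueError("`prompt_embeds` or `prompt` must also be provided with `prior_token_ids`.")
```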

Supported input combinations after this change

| Inputs | Behavior | Status |
| --- | --- | --- |
| `prompt` only | AR generates `prior_token_ids`; glyph encoder generates `prompt_embeds` | ✅ (unchanged) |
| `prior_token_ids` + `prompt_embeds` | Both used directly; AR step skipped | ✅ (unchanged) |
| `prompt` + `prior_token_ids` | `prior_token_ids` used directly (AR skipped); `prompt` generates `prompt_embeds` | 🆕 new |
| Neither `prompt` nor `prior_token_ids` | `ValueError` | ✅ (unchanged) |
| `prior_token_ids` alone (no `prompt`/`prompt_embeds`) | `ValueError` | ✅ (unchanged) |
| `prompt` + `prompt_embeds` | `ValueError` | ✅ (unchanged) |
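In code, the newly supported combination looks roughly like this (a sketch based on the PR's new test; `pipe` is assumed to be a loaded `GlmImagePipeline`, and `height`, `width`, and `device` are assumed to be defined):

```python
import torch

# Step 1: precompute the AR prior once, e.g. to reuse it across several runs.
# generate_prior_tokens() returns the prior token ids plus two other values,
# per the PR's test; only the token ids are needed here.
prior_token_ids, _, _ = pipe.generate_prior_tokens(
    prompt="A photo of a cat",
    height=height,
    width=width,
    device=torch.device(device),
    generator=torch.Generator(device=device).manual_seed(0),
)

# Step 2 (newly allowed): pass `prompt` so glyph `prompt_embeds` can be derived
# via encode_prompt(), together with the precomputed `prior_token_ids` so the
# expensive AR generation step is skipped.
result = pipe(
    prompt="A photo of a cat",
    prior_token_ids=prior_token_ids,
    num_inference_steps=2,
    guidance_scale=1.5,
    height=height,
    width=width,
    output_type="pt",
)
```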

Tests

Added two new tests to `test_glm_image.py`:

  • `test_prompt_with_prior_token_ids`: End-to-end test that first generates `prior_token_ids` via `generate_prior_tokens()`, then runs the full pipeline with both `prompt` and `prior_token_ids` provided together.
  • `test_check_inputs_rejects_invalid_combinations`: Unit test verifying that the three invalid input combinations listed above still correctly raise `ValueError`; a condensed sketch follows.
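A hypothetical, condensed version of the second test (the actual test in `test_glm_image.py` may be structured differently, e.g. as a method on the pipeline test class; `pipe`, `base_inputs`, `prompt_embeds`, and `prior_token_ids` are stand-in fixtures, not the real setup):

```python
import pytest

def test_check_inputs_rejects_invalid_combinations(pipe, base_inputs, prompt_embeds, prior_token_ids):
    # 1. Neither `prompt` nor `prior_token_ids` (and no `prompt_embeds`):
    #    the pipeline has nothing to condition on.
    with pytest.raises(ValueError):
        pipe(**base_inputs, prompt=None)

    # 2. `prior_token_ids` alone, without `prompt` or `prompt_embeds`:
    #    no source for the glyph embeddings.
    with pytest.raises(ValueError):
        pipe(**base_inputs, prompt=None, prior_token_ids=prior_token_ids)

    # 3. `prompt` together with `prompt_embeds` is still mutually exclusive.
    with pytest.raises(ValueError):
        pipe(**base_inputs, prompt="A photo of a cat", prompt_embeds=prompt_embeds)
```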

All existing tests continue to pass.

Signed-off-by: JaredforReal <w13431838023@gmail.com>
Copilot AI review requested due to automatic review settings February 7, 2026 03:48
JaredforReal (Contributor, Author) commented:

@yiyixuxu PTAL, thanks

Copilot AI (Contributor) left a comment


Pull request overview

This PR updates the GLM image pipeline input validation to allow supplying `prompt` together with precomputed `prior_token_ids`, enabling users to skip the AR prior generation step while still deriving glyph `prompt_embeds` from the `prompt`.

Changes:

  • Removed the mutual-exclusion validation between `prompt` and `prior_token_ids` in `GlmImagePipeline.check_inputs()`.
  • Relaxed the `prior_token_ids` validation to accept either `prompt` or `prompt_embeds` as the source for glyph embeddings.
  • Added tests covering the newly supported `prompt` + `prior_token_ids` combination and validating invalid input combinations.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| `src/diffusers/pipelines/glm_image/pipeline_glm_image.py` | Adjusts `check_inputs()` to permit `prompt` and `prior_token_ids` together and updates validation/error paths. |
| `tests/pipelines/glm_image/test_glm_image.py` | Adds new tests for the supported combo and for rejecting invalid input combinations. |


Comment on lines +300 to +315
```python
        inputs_prompt_only = {
            "prompt": "A photo of a cat",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 1.5,
            "height": height,
            "width": width,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
        prior_token_ids, _, _ = pipe.generate_prior_tokens(
            prompt="A photo of a cat",
            height=height,
            width=width,
            device=torch.device(device),
            generator=torch.Generator(device=device).manual_seed(0),
```

Copilot AI Feb 7, 2026


`inputs_prompt_only` (and its `generator`) is created but never used. This makes the test misleading (it implies a full pipeline run in step 1) and adds dead code; consider removing it, or actually running `pipe(**inputs_prompt_only)` if that's the intended setup.

Suggested change

```diff
-        inputs_prompt_only = {
-            "prompt": "A photo of a cat",
-            "generator": generator,
-            "num_inference_steps": 2,
-            "guidance_scale": 1.5,
-            "height": height,
-            "width": width,
-            "max_sequence_length": 16,
-            "output_type": "pt",
-        }
-        prior_token_ids, _, _ = pipe.generate_prior_tokens(
-            prompt="A photo of a cat",
-            height=height,
-            width=width,
-            device=torch.device(device),
-            generator=torch.Generator(device=device).manual_seed(0),
+        prior_token_ids, _, _ = pipe.generate_prior_tokens(
+            prompt="A photo of a cat",
+            height=height,
+            width=width,
+            device=torch.device(device),
+            generator=generator,
```

Comment on lines +692 to 694
```python
        if prior_token_ids is not None and prompt_embeds is None and prompt is None:
            raise ValueError("`prompt_embeds` or `prompt` must also be provided with `prior_token_ids`.")
```

Copilot AI Feb 7, 2026


The `prior_token_ids` validation at the end of `check_inputs` is effectively unreachable, because an earlier check already raises whenever both `prompt` and `prompt_embeds` are None. If the goal is a more specific error when `prior_token_ids` is provided without any prompt inputs, consider folding this condition into the earlier `prompt`/`prompt_embeds` validation (or removing this block to avoid redundant code).

Suggested change

```diff
-        if prior_token_ids is not None and prompt_embeds is None and prompt is None:
-            raise ValueError("`prompt_embeds` or `prompt` must also be provided with `prior_token_ids`.")
```
```diff
                 " only forward one of the two."
             )
-        elif prompt is None and prior_token_ids is None:
+        if prompt is None and prior_token_ids is None:
```

Copilot AI Feb 7, 2026


The error raised when `prompt` is None and `prior_token_ids` is None can be confusing if a user provided `prompt_embeds`: earlier validation allows `prompt_embeds` in place of `prompt`, but this message implies `prompt` is required. Consider clarifying the message to indicate that `prompt_embeds` alone is not sufficient, because `prompt` is needed to generate `prior_token_ids` when they are not provided.

Suggested change

```diff
         if prompt is None and prior_token_ids is None:
+            # At this point, `prompt_embeds` is guaranteed to be not None (the earlier check
+            # `elif prompt is None and prompt_embeds is None` would already have raised),
+            # so clarify that `prompt_embeds` alone is not sufficient to derive `prior_token_ids`.
+            if prompt_embeds is not None:
+                raise ValueError(
+                    "You have provided `prompt_embeds` without `prior_token_ids`, but a text `prompt` is also "
+                    "required so that `prior_token_ids` can be generated. Please provide either a text `prompt` "
+                    "so the pipeline can compute `prior_token_ids`, or pass `prior_token_ids` explicitly."
+                )
```

