[Bugfix] Enable step-wise execution#81
Conversation
Code Review
This pull request introduces support for pre-tokenized prompt IDs and step-wise execution in vllm_omni_rollout_adapter.py. Key additions include the _extract_prompt_ids, _tokenize_text_prompt, and prepare_encode methods, which facilitate the initialization of diffusion states from tokenized inputs. Additionally, the async server configuration now enables step_execution and sets a maximum sequence limit. Feedback indicates that since step_execution is enabled, the adapter must also override the per-step execution methods to ensure RL-specific fields are correctly collected. Other suggestions include correcting a class name typo in a ValueError and using the build_img_shapes helper function to reduce code duplication.
```python
        ).to(self.device)
        return tokens.input_ids, tokens.attention_mask

    def prepare_encode(
```
To fully support step-wise execution for RL rollouts, this class likely needs to override the per-step execution method (e.g., execute_step or step). The current implementation overrides diffuse to collect all_log_probs and all_latents during the denoising loop. If step_execution is enabled in the engine, the engine will bypass diffuse and call the per-step method instead. Without an override that performs similar data collection and state updates (like incrementing state.step_index), these RL-specific fields will be missing from the final output.
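For illustration, a per-step override might look roughly like the sketch below. This is a hedged sketch only: the method name execute_step, the state fields, and the scheduler call are assumptions based on this discussion, not the engine's actual API.

```python
# Illustrative sketch only: method name, state fields, and scheduler call are
# assumptions; they should mirror whatever per-step hook the engine invokes.
def execute_step(self, state):
    # One denoising step, mirroring what `diffuse` does per iteration.
    t = state.timesteps[state.step_index]
    noise_pred = self.transformer(state.latents, t, state.prompt_embeds)
    latents, log_prob = self.scheduler_step_with_logprob(noise_pred, t, state.latents)

    # Collect the RL-specific fields that `diffuse` normally accumulates.
    state.all_log_probs.append(log_prob)
    state.all_latents.append(latents)

    # Advance the per-request state so the engine can schedule the next step.
    state.latents = latents
    state.step_index += 1
    return state
```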
```python
            raise ValueError(
                "QwenImagePipelineWithLogProbForTest.prepare_encode requires either "
                "'prompt_ids' or a text 'prompt' in state.prompts[0]."
            )
```
The error message contains a typo in the class name, referring to QwenImagePipelineWithLogProbForTest instead of QwenImagePipelineWithLogProb.
```diff
-            raise ValueError(
-                "QwenImagePipelineWithLogProbForTest.prepare_encode requires either "
-                "'prompt_ids' or a text 'prompt' in state.prompts[0]."
-            )
+            raise ValueError(
+                "QwenImagePipelineWithLogProb.prepare_encode requires either "
+                "'prompt_ids' or a text 'prompt' in state.prompts[0]."
+            )
```
```python
            None,
        )

        img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size
```
This logic for building img_shapes is already implemented in the build_img_shapes utility function in common.py. It is better to use the utility to avoid code duplication and ensure consistency.
```diff
-        img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size
+        img_shapes = build_img_shapes(height, width, batch_size, self.vae_scale_factor)
```
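For reference, a minimal sketch of what the build_img_shapes helper in common.py presumably does, based on the duplicated expression it replaces; the real implementation in common.py may differ.

```python
# Sketch of the helper assumed to live in common.py; it mirrors the duplicated
# expression above, but the actual signature and body may differ.
def build_img_shapes(height: int, width: int, batch_size: int, vae_scale_factor: int):
    # One (frames, latent_height, latent_width) tuple per image, with an extra
    # factor of 2 from the pipeline's patchification.
    return [[(1, height // vae_scale_factor // 2, width // vae_scale_factor // 2)]] * batch_size
```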
References
- Avoid code duplication by reusing existing helper functions for common logic, such as constructing image shapes.
Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
I don't really understand why we need to add so many things to make it work.
The main problem is that vllm-omni does not officially support passing prompt_token_ids as input.
The main pipeline logic still tries to tokenize the prompt.
That's why the custom pipeline needs to align with the new prepare_encode function introduced by the step-wise execution path.
To make this compatibility stable, support from the vllm-omni side would be better.
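A rough sketch of the fallback this describes: accept pre-tokenized prompt_ids when present, otherwise tokenize the text prompt. The prompt-dict keys follow the ValueError message above; the real prepare_encode signature, state object, and helper return values may differ.

```python
# Sketch only: key names follow the ValueError message; the actual
# prepare_encode signature, state object, and helper returns may differ.
def prepare_encode(self, state):
    prompt = state.prompts[0]

    # Prefer pre-tokenized ids handed over by the RL trainer.
    prompt_ids = self._extract_prompt_ids(prompt)
    if prompt_ids is not None:
        input_ids, attention_mask = prompt_ids
    elif prompt.get("prompt"):
        # Otherwise fall back to tokenizing the raw text prompt.
        input_ids, attention_mask = self._tokenize_text_prompt(prompt["prompt"])
    else:
        raise ValueError(
            "QwenImagePipelineWithLogProb.prepare_encode requires either "
            "'prompt_ids' or a text 'prompt' in state.prompts[0]."
        )

    # Initialize the encode/diffusion state from the tokenized inputs.
    state.input_ids = input_ids
    state.attention_mask = attention_mask
    return state
```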
I see. So vllm-omni does not truly support --skip-tokenizer-init (https://docs.vllm.ai/en/stable/configuration/engine_args/#modelconfig) to accept prompt_token_ids.
Can we make a feature request for this?
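For reference, this is roughly how plain vLLM consumes pre-tokenized input once the tokenizer is skipped; vllm-omni would need an equivalent path. The model name and token ids below are placeholders, not from this PR.

```python
from vllm import LLM, SamplingParams

# Plain-vLLM reference for the requested behaviour: with skip_tokenizer_init
# the engine accepts prompt_token_ids directly instead of raw text.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", skip_tokenizer_init=True)  # placeholder model
outputs = llm.generate(
    [{"prompt_token_ids": [101, 2023, 2003, 1037, 3231, 102]}],  # placeholder ids
    SamplingParams(max_tokens=16, detokenize=False),
)
```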
Yes, I will work on it.
Does it conflict with the change in #66, which renames prompt_ids → prompt_token_ids for vllm-omni 0.20+?
What does this PR do?
A temporary fix to enable step-wise execution for vllm-omni.
Checklist Before Starting
- Title format: `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `vllm_omni`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `diffusion`, `omni`, `tests`, `docker`, like `[diffusion, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - Add `[BREAKING]` to the beginning of the title for breaking changes, e.g. `[BREAKING][diffusion, fsdp] feat: new rollout scheduler`

Test
API and Usage Example
```python
# Add code snippet or script demonstrating how to use this
```

Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always