[GKD] Use vllm for the student model #3475
base: main
Conversation
trl/trainer/gkd_config.py
Outdated
teacher_vllm_mode (`str`, *optional*, defaults to `"server"`):
    Mode for teacher vLLM integration. Either `"server"` (connect to a running TRL vLLM server) or
    `"colocate"` (run vLLM in the same process).
teacher_vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
I wonder if it's worth having separate vLLM arg configs to reuse across the GRPO/GKD trainers?
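A rough sketch of what that shared config could look like, a base dataclass that both `GRPOConfig` and `GKDConfig` inherit so the vLLM knobs are defined once. Field names here are illustrative, not the PR's actual API:

```python
from dataclasses import dataclass, field

# Illustrative only: shared vLLM arguments factored into a base dataclass.
# Names and defaults are hypothetical.
@dataclass
class VLLMArgs:
    vllm_mode: str = field(
        default="server",
        metadata={"help": "Either 'server' (connect to a running TRL vLLM server) or 'colocate'."},
    )
    vllm_server_host: str = field(
        default="0.0.0.0",
        metadata={"help": "Host of the vLLM server to connect to (server mode only)."},
    )
    vllm_server_port: int = field(
        default=8000,
        metadata={"help": "Port of the vLLM server to connect to (server mode only)."},
    )
```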
else:
    raise ValueError(f"Unknown student_vllm_mode: {self.student_vllm_mode}")
self.student_vllm_guided_decoding_regex = args.student_vllm_guided_decoding_regex
self.student_vllm_sync_frequency = args.student_vllm_sync_frequency
Do you have a feel for the impact of generating with a stale student policy?
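For context, the staleness window is bounded by the sync frequency: a sync-every-k-steps scheme means generations between pushes use a student policy that is at most k - 1 optimizer steps old. A toy illustration of the gating, with a hypothetical helper name:

```python
def should_sync_weights(global_step: int, sync_frequency: int) -> bool:
    # With sync_frequency = k, vLLM generates with weights that are
    # at most k - 1 optimizer steps stale between pushes.
    return global_step % sync_frequency == 0
```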
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Pull Request Overview

Adds support for using vLLM for the student model’s on-policy generation in `GKDTrainer`.

- Introduces new `student_use_vllm` flags and parameters in `GKDConfig`, with validation.
- Extends `GKDTrainer` to initialize vLLM in server or colocate mode, generate completions via vLLM, and sync weights.
- Updates documentation with an “Accelerated Generation with vLLM” section and refines `generalized_jsd_loss`.
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
File | Description
---|---
trl/trainer/gkd_config.py | Added student_use_vllm* fields, metadata, and max_new_tokens check |
trl/trainer/gkd_trainer.py | Integrated vLLM setup, generation, and parameter sync logic |
docs/source/gkd_trainer.md | Documented vLLM server vs. co-locate modes and usage guidance |
Comments suppressed due to low confidence (2)
trl/trainer/gkd_trainer.py:354
- The new vLLM-based generation path is complex and critical but has no accompanying tests. Consider adding unit or integration tests covering both 'server' and 'colocate' modes, sync timing, and error branches (e.g., unknown mode).
def _generate_on_policy_outputs_student_vllm(self, inputs, generation_config, pad_token_id=None):
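For the unknown-mode branch specifically, a test could stay at the config level once the `__post_init__` validation suggested in the next comment is in place. A hypothetical sketch, with field names following this PR:

```python
import pytest

from trl import GKDConfig


def test_unknown_student_vllm_mode_raises(tmp_path):
    # Assumes the early __post_init__ validation suggested below has
    # been added; otherwise the error surfaces only in the trainer.
    with pytest.raises(ValueError):
        GKDConfig(
            output_dir=str(tmp_path),
            student_use_vllm=True,
            student_vllm_mode="bogus",
        )
```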
trl/trainer/gkd_config.py:133
- It may be helpful to validate `student_vllm_mode` in `__post_init__` so only `'server'` or `'colocate'` are accepted, giving users early feedback on invalid values.
student_vllm_mode: str = field(
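A minimal sketch of that early check, assuming the field names in this PR and that `GKDConfig` already defines a `__post_init__` to chain to:

```python
def __post_init__(self):
    super().__post_init__()
    # Fail fast on invalid modes instead of erroring deep in trainer setup.
    if self.student_vllm_mode not in ("server", "colocate"):
        raise ValueError(
            f"student_vllm_mode must be 'server' or 'colocate', got {self.student_vllm_mode!r}"
        )
```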
)

# Create subgroups of ranks for TP
self.student_tp_group, _ = torch.distributed.new_subgroups_by_enumeration(
`new_subgroups_by_enumeration` returns a list of ProcessGroups and a default group. Assigning the whole list to `student_tp_group` means `all_gather_object` later gets a list instead of a single ProcessGroup. You should pick the subgroup for the current rank (e.g., by matching `torch.distributed.get_rank()` against the ranks in each subgroup) before using it.
- self.student_tp_group, _ = torch.distributed.new_subgroups_by_enumeration(
+ subgroups, _ = torch.distributed.new_subgroups_by_enumeration(
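One way to make the rank matching explicit, assuming `tp_rank_lists` holds the enumerated rank groups. Note that per the PyTorch documentation the first return value of `new_subgroups_by_enumeration` is the subgroup containing the current rank, so it may already be usable directly:

```python
import torch.distributed as dist

# tp_rank_lists is assumed, e.g. [[0, 1], [2, 3]] for TP size 2 on 4 ranks.
cur_subgroup, all_subgroups = dist.new_subgroups_by_enumeration(tp_rank_lists)

# Explicit selection, as the review suggests: match this rank against the
# ranks of each subgroup in the full list.
rank = dist.get_rank()
student_tp_group = next(
    pg for pg in all_subgroups if rank in dist.get_process_group_ranks(pg)
)
```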
@@ -145,6 +158,70 @@ def __init__(
        ):
            self.generation_config.eos_token_id = self.model.generation_config.eos_token_id

        # vLLM setup for student model if enabled
        self.student_use_vllm = args.student_use_vllm
        if self.student_use_vllm:
`self.student_vllm_client` is only set on the main process for server mode, so other ranks won't have this attribute. To avoid a potential `AttributeError` during error handling or shutdown, initialize `student_vllm_client = None` unconditionally before the `is_main_process` block.
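A sketch of the defensive initialization being suggested; attribute and config names follow this PR, and the client construction is deferred to a hypothetical helper:

```python
# Ensure the attribute exists on every rank, including non-main ones.
self.student_vllm_client = None
if self.student_vllm_mode == "server" and self.accelerator.is_main_process:
    # _init_student_vllm_client is a hypothetical stand-in for the PR's
    # server-mode client construction.
    self.student_vllm_client = self._init_student_vllm_client()
```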
What does this PR do?
Adds an option to use vLLM for the teacher model.

Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.