[GKD] Use vllm for the student model #3475


Open

wants to merge 26 commits into main
Conversation

kashif
Collaborator

@kashif kashif commented May 21, 2025

What does this PR do?

Adds an option to use vLLM for the student model.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@kashif kashif marked this pull request as draft May 21, 2025 13:21
teacher_vllm_mode (`str`, *optional*, defaults to `"server"`):
Mode for teacher vLLM integration. Either `"server"` (connect to a running TRL vLLM server) or
`"colocate"` (run vLLM in the same process).
teacher_vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
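The two modes in the quoted docstring can be illustrated with a small sketch. The field names mirror the snippet above; the dataclass and helper are purely illustrative, not the actual TRL implementation:

```python
from dataclasses import dataclass


@dataclass
class TeacherVLLMSettings:
    # Field names taken from the docstring above; defaults match it too.
    teacher_vllm_mode: str = "server"
    teacher_vllm_server_host: str = "0.0.0.0"


def describe(cfg: TeacherVLLMSettings) -> str:
    # "server": connect to a running TRL vLLM server over the network.
    if cfg.teacher_vllm_mode == "server":
        return f"connect to a running TRL vLLM server at {cfg.teacher_vllm_server_host}"
    # "colocate": run the vLLM engine inside the trainer process itself.
    if cfg.teacher_vllm_mode == "colocate":
        return "run vLLM in the same process as the trainer"
    raise ValueError(f"unknown mode: {cfg.teacher_vllm_mode}")
```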
Contributor

I wonder if it's worth having separate vLLM arg configs to reuse across the GRPO/GKD trainers?
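One way the reviewer's suggestion could be realized is a shared dataclass of vLLM options that both the GRPO and GKD configs inherit or embed. The names below are illustrative only, not the actual TRL API:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class VLLMArgs:
    # Options common to any trainer that generates with vLLM.
    mode: str = "server"            # "server" or "colocate"
    server_host: str = "0.0.0.0"
    server_port: int = 8000
    sync_frequency: int = 1         # sync model weights to vLLM every N steps


@dataclass
class GKDVLLMConfig(VLLMArgs):
    # GKD-specific extras layer on top of the shared options.
    guided_decoding_regex: Optional[str] = None
```

A GRPO-specific config could subclass `VLLMArgs` the same way, so both trainers parse and validate one shared set of flags.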

        else:
            raise ValueError(f"Unknown student_vllm_mode: {self.student_vllm_mode}")
        self.student_vllm_guided_decoding_regex = args.student_vllm_guided_decoding_regex
        self.student_vllm_sync_frequency = args.student_vllm_sync_frequency
Contributor

Do you have a feel for the impact of generating with a stale student policy?
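The staleness the reviewer asks about follows directly from the sync schedule: if weights are pushed to vLLM every `student_vllm_sync_frequency` steps, generations in between come from a policy up to N-1 steps old. A minimal sketch of that schedule (the function name is assumed, not taken from the PR):

```python
def should_sync(step: int, sync_frequency: int) -> bool:
    # Weights are refreshed in the vLLM engine only on these steps;
    # every other step generates with a stale copy of the student.
    return step % sync_frequency == 0


# With sync_frequency=4, only steps 0 and 4 out of the first 8 sync,
# so steps 1-3 and 5-7 generate with a policy up to 3 steps stale.
synced = [s for s in range(8) if should_sync(s, 4)]
```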

@kashif kashif changed the title [GKD] Use vllm for the teacher model [GKD] Use vllm for the student model Jun 4, 2025
@kashif kashif marked this pull request as ready for review June 23, 2025 10:58
@kashif kashif requested a review from Copilot June 23, 2025 11:03
Copilot

This comment was marked as outdated.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kashif kashif requested a review from Copilot June 26, 2025 15:25
Contributor

@Copilot Copilot AI left a comment

Pull Request Overview

Adds support for using vLLM for the student model’s on-policy generation in GKDTrainer.

  • Introduces new student_use_vllm flags and parameters in GKDConfig, with validation.
  • Extends GKDTrainer to initialize vLLM in server or colocate mode, generate completions via vLLM, and sync weights.
  • Updates documentation with an “Accelerated Generation with vLLM” section and refines generalized_jsd_loss.

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

Reviewed files:
  trl/trainer/gkd_config.py: Added student_use_vllm* fields, metadata, and a max_new_tokens check
  trl/trainer/gkd_trainer.py: Integrated vLLM setup, generation, and parameter-sync logic
  docs/source/gkd_trainer.md: Documented vLLM server vs. colocate modes and usage guidance
Comments suppressed due to low confidence (2)

trl/trainer/gkd_trainer.py:354

  • The new vLLM-based generation path is complex and critical but has no accompanying tests. Consider adding unit or integration tests covering both 'server' and 'colocate' modes, sync timing, and error branches (e.g., unknown mode).
    def _generate_on_policy_outputs_student_vllm(self, inputs, generation_config, pad_token_id=None):

trl/trainer/gkd_config.py:133

  • It may be helpful to validate student_vllm_mode in __post_init__ so only 'server' or 'colocate' are accepted, giving users early feedback on invalid values.
    student_vllm_mode: str = field(

)
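The early-validation idea from the second suggestion could look like this. The field name matches the snippet above; the surrounding dataclass is a minimal sketch, not the real `GKDConfig`:

```python
from dataclasses import dataclass, field


@dataclass
class StudentVLLMConfig:
    student_vllm_mode: str = field(default="server")

    def __post_init__(self):
        # Fail fast at config construction instead of deep inside the
        # trainer, so typos surface with a clear message.
        if self.student_vllm_mode not in ("server", "colocate"):
            raise ValueError(
                f"student_vllm_mode must be 'server' or 'colocate', "
                f"got {self.student_vllm_mode!r}"
            )
```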

# Create subgroups of ranks for TP
self.student_tp_group, _ = torch.distributed.new_subgroups_by_enumeration(
Copilot AI Jun 26, 2025

new_subgroups_by_enumeration returns a list of ProcessGroups and a default group. Assigning the whole list to student_tp_group means all_gather_object later gets a list instead of a single ProcessGroup. You should pick the subgroup for the current rank (e.g., by matching torch.distributed.get_rank() against the ranks in each subgroup) before using it.

Suggested change:

    - self.student_tp_group, _ = torch.distributed.new_subgroups_by_enumeration(
    + subgroups, _ = torch.distributed.new_subgroups_by_enumeration(
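The selection step the comment describes can be sketched without a live distributed backend: given the list of groups returned by `new_subgroups_by_enumeration` and the rank lists that produced them, pick the group containing the current rank. The function below models that logic with plain values instead of real `ProcessGroup` objects:

```python
def pick_subgroup(subgroups, rank_lists, my_rank):
    # new_subgroups_by_enumeration returns one group per enumerated rank
    # list; only the group containing this process's rank is usable here.
    for group, ranks in zip(subgroups, rank_lists):
        if my_rank in ranks:
            return group
    raise RuntimeError(f"rank {my_rank} not found in any TP subgroup")
```

In the trainer this would run right after the `new_subgroups_by_enumeration` call, with `my_rank = torch.distributed.get_rank()`.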


@@ -145,6 +158,70 @@ def __init__(
):
self.generation_config.eos_token_id = self.model.generation_config.eos_token_id

# vLLM setup for student model if enabled
self.student_use_vllm = args.student_use_vllm
if self.student_use_vllm:
Copilot AI Jun 26, 2025

self.student_vllm_client is only set on the main process for server mode, so other ranks won't have this attribute. To avoid potential AttributeError during error handling or shutdown, initialize student_vllm_client=None unconditionally before the is_main_process block.

