-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[sglang, rollout] feat: support sglang as rollout engine in fully async policy #4191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
wuxibin89
merged 30 commits into
verl-project:main
from
meituan-search:recipe/async_policy_sglang
Jan 19, 2026
Merged
Changes from 14 commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
e18ba0d
sglang rollout in fully async
AniZpZ 0aec584
upd
AniZpZ 48103dc
minor fix
AniZpZ c9e7d09
Merge remote-tracking branch 'origin/main' into recipe/async_policy_s…
AniZpZ 815ebb2
Merge branch 'main' into recipe/async_policy_sglang
AniZpZ 814be85
fix ValueError and MemoryAllocateError
ce64a14
Merge remote-tracking branch 'origin/main' into recipe/async_policy_s…
AniZpZ a1ae8db
fmt
AniZpZ f6c7589
Keep the original logic of enable_memory_saver and reduce the paramet…
0b13a30
Merge remote-tracking branch 'origin/main' into recipe/async_policy_s…
AniZpZ d6e7ab7
update clear_kv_cache
AniZpZ 4492587
upd
AniZpZ 2b0c897
upda
AniZpZ 5c68371
fmt
AniZpZ a3bdc97
Fixed some minor issues with using Megatron to replace FSDP.
8edae0d
Resolved some conflicts and improved parameter synchronization strate…
afb1ad9
Resolved conflict between the partial rollout and the memory saver fe…
f424edb
Fix the issue of incorrect calculation for rollouter/active_time.
8dc3303
Merge remote-tracking branch 'origin/main' into recipe/async_policy_s…
AniZpZ 64aba78
Minor rectification after rebase.
03cc18d
fix pre-commit
jsfanfanfan dbe9a45
minor fix
jsfanfanfan 0121750
minor fix
8257a08
minor fix
5721903
Merge branch 'volcengine:main' into recipe/async_policy_sglang
jsfanfanfan 72ddbb1
fix TypeError: got an unexpected keyword argument 'base_gpu_id'
80f277f
Use lazy import to avoid ModuleNotFoundError
f9994a6
fix base_gpu_id missing!
67096c3
pre-commit
b0f905a
remove reduntant classes and lines.
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -13,8 +13,10 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| import asyncio | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| import logging | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| import threading | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| import time | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -33,7 +35,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||
| load_fsdp_model_to_gpu, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| offload_fsdp_model_to_cpu, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| from verl.workers.fsdp_workers import AsyncActorRolloutRefWorker, CriticWorker | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| from .checkpoint_engine import CheckpointEngine | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -67,6 +69,41 @@ def get_inference_model(rollout): | |||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| class DetachNcclSync(AsyncActorRolloutRefWorker): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, config: DictConfig, role: str): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__(config, role) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._bg_loop = asyncio.new_event_loop() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._bg_thread = threading.Thread( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| target=self._start_background_loop, args=(self._bg_loop,), name="rollout_actor_async_worker", daemon=True | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._bg_thread.start() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info(f"[DetachNcclSync] Background thread for SGLang sync started. PID: {os.getpid()}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _start_background_loop(self, loop): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| asyncio.set_event_loop(loop) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| loop.run_forever() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.error(f"[DetachNcclSync] Background loop crashed: {e}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _run_async_safely(self, coro): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not self._bg_thread.is_alive(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise RuntimeError("Background thread for SGLang sync is not running!") | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| future = asyncio.run_coroutine_threadsafe(coro, self._bg_loop) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| return future.result() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _flush_sglang_batch(self, inference_model, batch_data): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| batch_copy = list(batch_data) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._run_async_safely(self.update_weights(inference_model, batch_copy)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __del__(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if hasattr(self, "_bg_loop") and self._bg_loop.is_running(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._bg_loop.call_soon_threadsafe(self._bg_loop.stop) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if hasattr(self, "_bg_thread") and self._bg_thread.is_alive(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._bg_thread.join(timeout=1.0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| def init_checkpoint_engine(self, rank_offset: int, actor_num: int, rollout_num: int): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| current_rank = torch.distributed.get_rank() + rank_offset | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -89,12 +126,46 @@ def sync_rollout_weights(self, sync_group_name="actor_rollout"): | |||||||||||||||||||||||||||||||||||||||||||||||||||
| if self._is_actor and self._is_offload_param: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| load_fsdp_model_to_gpu(self.actor_module_fsdp) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| params = self._get_actor_params() if self._is_actor else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| rollout_name = self.config.rollout.name | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| inference_model = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self._is_rollout: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| inference_model = get_inference_model(self.rollout) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if rollout_name == "vllm": | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| inference_model = get_inference_model(self.rollout) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| patch_vllm_moe_model_weight_loader(inference_model) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif rollout_name == "sglang": | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| inference_model = self.rollout._engine | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| # For ServerAdapter, _engine might be None and needs async initialization | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if inference_model is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Initialize the server adapter engine | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| print("[sync_rollout_weights] Initialize server adapter engine") | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def init_engine(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if hasattr(self.rollout, "_init_server_adapter"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| await self.rollout._init_server_adapter() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| print("[sync_rollout_weights] No _init_server_adapter method found") | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self.rollout._engine | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| inference_model = self._run_async_safely(init_engine()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if inference_model is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise RuntimeError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"Failed to initialize rollout engine. " | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"rollout type: {type(self.rollout)}, " | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"has _init_server_adapter: {hasattr(self.rollout, '_init_server_adapter')}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise NotImplementedError(f"Unknown rollout name: {rollout_name}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ray.util.collective import collective | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_bucket_bytes = 8 * 1024 * 1024 * 1024 # 8GB | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| bucket = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| bucket_bytes = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| patch_vllm_moe_model_weight_loader(inference_model) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| for key, shape, dtype in self._weights_info: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self._is_actor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -104,16 +175,43 @@ def sync_rollout_weights(self, sync_group_name="actor_rollout"): | |||||||||||||||||||||||||||||||||||||||||||||||||||
| origin_data = origin_data.full_tensor() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if torch.distributed.get_rank() == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| tensor.copy_(origin_data) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ray.util.collective import collective | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| collective.broadcast(tensor, src_rank=0, group_name=sync_group_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| collective.broadcast(tensor, src_rank=0, group_name=sync_group_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self._is_rollout: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| inference_model.load_weights([(key, tensor)]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self._is_rollout: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| # batching tensors for sglang | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if rollout_name == "sglang": | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| bucket.append((key, tensor)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| bucket_bytes += tensor.numel() * tensor.element_size() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| if bucket_bytes >= max_bucket_bytes: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._flush_sglang_batch(inference_model, bucket) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| bucket.clear() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| bucket_bytes = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if rollout_name == "vllm": | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| inference_model.load_weights([(key, tensor)]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be indented
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self._is_rollout and rollout_name == "sglang" and bucket: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._flush_sglang_batch(inference_model, bucket) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| bucket.clear() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self._is_actor and self._is_offload_param: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| offload_fsdp_model_to_cpu(self.actor_module_fsdp) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| get_torch_device().empty_cache() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def update_weights(self, inference_engine, params): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| await sgl_update_weights( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| engine=inference_engine, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| params_batch=params, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| device_mesh_key="infer_tp", | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| device_mesh=self.rollout_device_mesh, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.rollout_device_mesh["infer_tp"].get_local_rank() == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| await inference_engine.flush_cache() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| def cache_actor_weights_to_cpu(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.cpu_named_params = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self._is_actor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -182,6 +280,10 @@ def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"): | |||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| class DetachActorWorker(DetachNcclSync): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, config: DictConfig, role: str): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| print("[DetachAsyncRolloutWorker] Initializing via DetachNcclSync...") | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| DetachNcclSync.__init__(self, config, role) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _get_actor_params(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert self._is_actor | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| params = self.actor_module_fsdp.state_dict() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -233,7 +335,7 @@ def clear_cpu_model(self, n): | |||||||||||||||||||||||||||||||||||||||||||||||||||
| class DetachAsyncRolloutWorker(DetachNcclSync): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, config: DictConfig, role: str): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ActorRolloutRefWorker.__init__(self, config, role) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| DetachNcclSync.__init__(self, config, role) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| @register(dispatch_mode=Dispatch.ONE_TO_ALL) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| def set_actor_weights_info(self, weights_info): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should be indented