Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e18ba0d
sglang rollout in fully async
AniZpZ Nov 19, 2025
0aec584
upd
AniZpZ Nov 21, 2025
48103dc
minor fix
AniZpZ Nov 21, 2025
c9e7d09
Merge remote-tracking branch 'origin/main' into recipe/async_policy_s…
AniZpZ Dec 5, 2025
815ebb2
Merge branch 'main' into recipe/async_policy_sglang
AniZpZ Dec 5, 2025
814be85
fix ValueError and MemoryAllocateError
Dec 17, 2025
ce64a14
Merge remote-tracking branch 'origin/main' into recipe/async_policy_s…
AniZpZ Dec 17, 2025
a1ae8db
fmt
AniZpZ Dec 17, 2025
f6c7589
Keep the original logic of enable_memory_saver and reduce the paramet…
Dec 18, 2025
0b13a30
Merge remote-tracking branch 'origin/main' into recipe/async_policy_s…
AniZpZ Dec 19, 2025
d6e7ab7
update clear_kv_cache
AniZpZ Dec 19, 2025
4492587
upd
AniZpZ Dec 24, 2025
2b0c897
upda
AniZpZ Dec 24, 2025
5c68371
fmt
AniZpZ Dec 24, 2025
a3bdc97
Fixed some minor issues with using Megatron to replace FSDP.
Dec 30, 2025
8edae0d
Resolved some conflicts and improved parameter synchronization strate…
Jan 5, 2026
afb1ad9
Resolved conflict between the partial rollout and the memory saver fe…
Jan 6, 2026
f424edb
Fix the issue of incorrect calculation for rollouter/active_time.
Jan 7, 2026
8dc3303
Merge remote-tracking branch 'origin/main' into recipe/async_policy_s…
AniZpZ Jan 9, 2026
64aba78
Minor rectification after rebase.
Jan 9, 2026
03cc18d
fix pre-commit
jsfanfanfan Jan 10, 2026
dbe9a45
minor fix
jsfanfanfan Jan 11, 2026
0121750
minor fix
Jan 14, 2026
8257a08
minor fix
Jan 15, 2026
5721903
Merge branch 'volcengine:main' into recipe/async_policy_sglang
jsfanfanfan Jan 15, 2026
72ddbb1
fix TypeError: got an unexpected keyword argument 'base_gpu_id'
Jan 15, 2026
80f277f
Use lazy import to avoid ModuleNotFoundError
Jan 15, 2026
f9994a6
fix base_gpu_id missing!
Jan 16, 2026
67096c3
pre-commit
Jan 16, 2026
b0f905a
remove redundant classes and lines.
Jan 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion recipe/fully_async_policy/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import ray
from omegaconf import DictConfig

from recipe.fully_async_policy.sglang_rollout.sglang_async_server import FullyAsyncSGLangReplica
from recipe.fully_async_policy.vllm_rollout.vllm_async_server import FullyAsyncvLLMReplica
from verl.experimental.agent_loop.agent_loop import (
AgentLoopManager,
Expand Down Expand Up @@ -215,7 +216,17 @@ def __init__(self, config: DictConfig, worker_group: RayWorkerGroup = None, rm_w
self.reward_model_manager = None
self.reward_router_address = None
self.agent_loop_workers_class = FullyAsyncAgentLoopWorker
self.rollout_replica_class = FullyAsyncvLLMReplica

# Select rollout replica class based on rollout name
rollout_name = config.actor_rollout_ref.rollout.name
if rollout_name == "sglang":
self.rollout_replica_class = FullyAsyncSGLangReplica
print("[FullyAsyncAgentLoopManager] SGLang replica class selected")
elif rollout_name == "vllm":
self.rollout_replica_class = FullyAsyncvLLMReplica
print("[FullyAsyncAgentLoopManager] vLLM replica class selected")
else:
raise ValueError(f"Unsupported rollout name: {rollout_name}. Supported values are 'sglang' and 'vllm'.")

self.rm_wg = rm_wg
self.rollout_replicas = None
Expand Down Expand Up @@ -323,5 +334,27 @@ async def wake_up(self):
async def sleep(self):
await asyncio.gather(*[replica.sleep() for replica in self.rollout_replicas])

async def reset_prefix_cache(self, timeout: float = 5.0):
    """Reset the prefix cache on every rollout replica concurrently.

    Each replica is reset independently under a per-replica timeout so a
    single hung replica cannot block the others. Timeouts and errors are
    logged and swallowed so the pause/weight-sync flow proceeds best-effort.

    Args:
        timeout: Per-replica timeout in seconds (default preserves the
            previous hard-coded 5.0s).
    """
    print("[FullyAsyncAgentLoopManager] Reset prefix cache ...")

    async def _reset_one(idx, replica):
        print(f"[reset_prefix_cache] start replica={idx}")
        try:
            await asyncio.wait_for(replica.reset_prefix_cache(), timeout=timeout)
        except asyncio.TimeoutError:
            print(f"[reset_prefix_cache] TIMEOUT replica={idx} after {timeout}s")
            return
        except Exception as e:
            # Best-effort: a failed reset on one replica must not abort the rest.
            print(f"[reset_prefix_cache] ERROR replica={idx}: {e!r}")
            return
        print(f"[reset_prefix_cache] done replica={idx}")

    tasks = [_reset_one(i, replica) for i, replica in enumerate(self.rollout_replicas)]
    await asyncio.gather(*tasks, return_exceptions=True)
    print("[FullyAsyncAgentLoopManager] Reset prefix cache finished")

async def clear_kv_cache(self):
    """Clear the KV cache on all rollout replicas concurrently."""
    pending = [replica.clear_kv_cache() for replica in self.rollout_replicas]
    await asyncio.gather(*pending)
120 changes: 111 additions & 9 deletions recipe/fully_async_policy/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging
import os
import threading
import time

import torch
Expand All @@ -33,7 +35,7 @@
load_fsdp_model_to_gpu,
offload_fsdp_model_to_cpu,
)
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
from verl.workers.fsdp_workers import AsyncActorRolloutRefWorker, CriticWorker

from .checkpoint_engine import CheckpointEngine

Expand Down Expand Up @@ -67,6 +69,41 @@ def get_inference_model(rollout):


class DetachNcclSync(AsyncActorRolloutRefWorker):
def __init__(self, config: DictConfig, role: str):
    """Initialize the worker and start a dedicated background asyncio loop.

    The loop runs forever on a daemon thread and is used to drive async
    SGLang weight-sync coroutines from synchronous worker methods
    (via ``_run_async_safely``).

    Args:
        config: Worker configuration, forwarded to the base class.
        role: Worker role string, forwarded to the base class.
    """
    super().__init__(config, role)

    # Dedicated event loop + daemon thread: synchronous Ray worker code can
    # submit coroutines here without needing a running loop of its own.
    self._bg_loop = asyncio.new_event_loop()
    self._bg_thread = threading.Thread(
        target=self._start_background_loop, args=(self._bg_loop,), name="rollout_actor_async_worker", daemon=True
    )
    self._bg_thread.start()
    logger.info(f"[DetachNcclSync] Background thread for SGLang sync started. PID: {os.getpid()}")

def _start_background_loop(self, loop):
    """Thread target: install *loop* on this thread and run it until stopped.

    Any exception escaping ``run_forever`` is logged (a crash on the daemon
    thread would otherwise die silently).
    """
    asyncio.set_event_loop(loop)
    try:
        loop.run_forever()
    except Exception as e:
        logger.error(f"[DetachNcclSync] Background loop crashed: {e}")

def _run_async_safely(self, coro):
    """Schedule *coro* on the background loop and block until it completes.

    Raises:
        RuntimeError: if the background thread is no longer alive.
    """
    if not self._bg_thread.is_alive():
        raise RuntimeError("Background thread for SGLang sync is not running!")
    # run_coroutine_threadsafe is the thread-safe bridge into the loop owned
    # by the background thread; .result() blocks the caller until done.
    return asyncio.run_coroutine_threadsafe(coro, self._bg_loop).result()

def _flush_sglang_batch(self, inference_model, batch_data):
    """Push one bucket of (name, tensor) pairs to the SGLang engine.

    A snapshot of *batch_data* is taken so the caller may clear/reuse its
    bucket list while the update is in flight.
    """
    snapshot = list(batch_data)
    update = self.update_weights(inference_model, snapshot)
    self._run_async_safely(update)

def __del__(self):
    """Best-effort teardown: stop the background loop and join its thread."""
    # Guard with hasattr: __init__ may have failed before these were set.
    loop_running = hasattr(self, "_bg_loop") and self._bg_loop.is_running()
    if loop_running:
        self._bg_loop.call_soon_threadsafe(self._bg_loop.stop)
    thread_alive = hasattr(self, "_bg_thread") and self._bg_thread.is_alive()
    if thread_alive:
        self._bg_thread.join(timeout=1.0)

@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
def init_checkpoint_engine(self, rank_offset: int, actor_num: int, rollout_num: int):
current_rank = torch.distributed.get_rank() + rank_offset
Expand All @@ -89,12 +126,46 @@ def sync_rollout_weights(self, sync_group_name="actor_rollout"):
if self._is_actor and self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
params = self._get_actor_params() if self._is_actor else None
rollout_name = self.config.rollout.name

inference_model = None
if self._is_rollout:
inference_model = get_inference_model(self.rollout)
if rollout_name == "vllm":
inference_model = get_inference_model(self.rollout)

from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader

patch_vllm_moe_model_weight_loader(inference_model)
elif rollout_name == "sglang":
inference_model = self.rollout._engine
# For ServerAdapter, _engine might be None and needs async initialization
if inference_model is None:
# Initialize the server adapter engine
print("[sync_rollout_weights] Initialize server adapter engine")

async def init_engine():
if hasattr(self.rollout, "_init_server_adapter"):
await self.rollout._init_server_adapter()
else:
print("[sync_rollout_weights] No _init_server_adapter method found")
return self.rollout._engine

inference_model = self._run_async_safely(init_engine())
if inference_model is None:
raise RuntimeError(
f"Failed to initialize rollout engine. "
f"rollout type: {type(self.rollout)}, "
f"has _init_server_adapter: {hasattr(self.rollout, '_init_server_adapter')}"
)
else:
raise NotImplementedError(f"Unknown rollout name: {rollout_name}")

from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
from ray.util.collective import collective

max_bucket_bytes = 8 * 1024 * 1024 * 1024 # 8GB
bucket = []
bucket_bytes = 0

patch_vllm_moe_model_weight_loader(inference_model)
for key, shape, dtype in self._weights_info:
tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
if self._is_actor:
Expand All @@ -104,16 +175,43 @@ def sync_rollout_weights(self, sync_group_name="actor_rollout"):
origin_data = origin_data.full_tensor()
if torch.distributed.get_rank() == 0:
tensor.copy_(origin_data)
from ray.util.collective import collective
collective.broadcast(tensor, src_rank=0, group_name=sync_group_name)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be indented

Suggested change
collective.broadcast(tensor, src_rank=0, group_name=sync_group_name)
collective.broadcast(tensor, src_rank=0, group_name=sync_group_name)


collective.broadcast(tensor, src_rank=0, group_name=sync_group_name)
if self._is_rollout:
inference_model.load_weights([(key, tensor)])
if self._is_rollout:
# batching tensors for sglang
if rollout_name == "sglang":
bucket.append((key, tensor))
bucket_bytes += tensor.numel() * tensor.element_size()

if bucket_bytes >= max_bucket_bytes:
self._flush_sglang_batch(inference_model, bucket)
bucket.clear()
bucket_bytes = 0
else:
if rollout_name == "vllm":
inference_model.load_weights([(key, tensor)])
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be indented

Suggested change
if self._is_rollout:
# batching tensors for sglang
if rollout_name == "sglang":
bucket.append((key, tensor))
bucket_bytes += tensor.numel() * tensor.element_size()
if bucket_bytes >= max_bucket_bytes:
self._flush_sglang_batch(inference_model, bucket)
bucket.clear()
bucket_bytes = 0
else:
if rollout_name == "vllm":
inference_model.load_weights([(key, tensor)])
if self._is_rollout:
# batching tensors for sglang
if rollout_name == "sglang":
bucket.append((key, tensor))
bucket_bytes += tensor.numel() * tensor.element_size()
if bucket_bytes >= max_bucket_bytes:
self._flush_sglang_batch(inference_model, bucket)
bucket.clear()
bucket_bytes = 0
elif rollout_name == "vllm":
inference_model.load_weights([(key, tensor)])


if self._is_rollout and rollout_name == "sglang" and bucket:
self._flush_sglang_batch(inference_model, bucket)
bucket.clear()

if self._is_actor and self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
get_torch_device().empty_cache()

async def update_weights(self, inference_engine, params):
    """Push a batch of weights into the SGLang engine, then flush its cache.

    Args:
        inference_engine: SGLang engine/adapter to receive the weights.
        params: Batch of (name, tensor) pairs to load.

    NOTE(review): the flush is issued only on infer_tp local rank 0 —
    presumably the engine broadcasts/handles the other ranks internally;
    confirm against the sglang weight-sync contract.
    """
    # Lazy import: keeps sglang an optional dependency for vLLM-only runs.
    from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights

    await sgl_update_weights(
        engine=inference_engine,
        params_batch=params,
        device_mesh_key="infer_tp",
        device_mesh=self.rollout_device_mesh,
    )

    # Drop stale KV entries so generation after the sync uses new weights.
    if self.rollout_device_mesh["infer_tp"].get_local_rank() == 0:
        await inference_engine.flush_cache()

def cache_actor_weights_to_cpu(self):
self.cpu_named_params = {}
if self._is_actor:
Expand Down Expand Up @@ -182,6 +280,10 @@ def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"):


class DetachActorWorker(DetachNcclSync):
def __init__(self, config: DictConfig, role: str):
    """Initialize the actor worker through the shared DetachNcclSync base."""
    # Fix copy-paste bug: the log line previously claimed to be
    # DetachAsyncRolloutWorker while running inside DetachActorWorker.
    print("[DetachActorWorker] Initializing via DetachNcclSync...")
    DetachNcclSync.__init__(self, config, role)

def _get_actor_params(self):
assert self._is_actor
params = self.actor_module_fsdp.state_dict()
Expand Down Expand Up @@ -233,7 +335,7 @@ def clear_cpu_model(self, n):
class DetachAsyncRolloutWorker(DetachNcclSync):
def __init__(self, config: DictConfig, role: str):
    """Initialize the rollout worker through the shared DetachNcclSync base."""
    print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}")
    # Route through DetachNcclSync (which sets up the background sync loop);
    # the direct ActorRolloutRefWorker.__init__ call it replaces is removed.
    DetachNcclSync.__init__(self, config, role)

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def set_actor_weights_info(self, weights_info):
Expand Down
4 changes: 3 additions & 1 deletion recipe/fully_async_policy/fully_async_rollouter.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,16 +669,18 @@ async def _should_pause_generation(self) -> bool:

async def pause(self):
"""pause rollout"""
print("[FullyAsyncRollouter][Public][Pause]")
print("[FullyAsyncRollouter][Public][Pause] partial rollout: ", {self.config.async_training.partial_rollout})
async with self.lock:
self.paused = True
# Cancel all rollout tasks
if self.config.async_training.partial_rollout:
await self.async_rollout_manager.cancel()
print("[FullyAsyncRollouter][Public][Pause] Unfinished rollout tasks canceled")
if self.active_tasks:
await asyncio.gather(*self.active_tasks, return_exceptions=True)
self.active_tasks.clear()
print("[FullyAsyncRollouter][Public][Pause] All active tasks completed")
print("[FullyAsyncRollouter][Public][Pause] Prefix cache reset")
await self.async_rollout_manager.clear_kv_cache()
self.monitor_loop_trigger = False

Expand Down
Loading