
Commit d324b01

[fully_async] feat: reuse trainer worker group for hybrid rollout to do validation (#6076)
### Overview

By dynamically adding or removing replicas at runtime, this PR fixes the `use_trainer_do_validate` capability that was broken in fully-async training mode. It also provides the infrastructure needed for future elastic scheduling and resilience work.

![architecture](https://github.com/user-attachments/assets/8826ac01-745c-4b62-bc40-fcecc294b0fb)

### What Changed

#### 1. Merged Handle Registry into GlobalRequestLoadBalancer

The original architecture kept a local `servers: dict[str, ActorHandle]` cache inside each `LLMServerClient`. This made elastic scaling impossible without broadcasting updates to every client/worker.

**Before (2 RPCs per acquire):**

```
server_id = LB.acquire(request_id)   # RPC 1
handle = client.servers[server_id]   # local lookup (stale if elastic add/remove happened)
```

**After (1 atomic RPC per acquire):**

```
(server_id, handle) = LB.acquire(request_id)  # single RPC, always consistent
```

The `GlobalRequestLoadBalancer` now owns both the routing pool (`_inflight_requests`) and the handle mapping (`_servers`). Elastic `add_replica()` / `remove_replica()` each require only **one Ray RPC**; no client/worker notification is needed.

#### 2. FullyAsyncLLMServerManager: Two-Phase Initialization + Elastic Lifecycle

New subclass of `LLMServerManager` that supports:

- **Phase 1, elastic hybrid replicas** (ranks `0..N_e-1`): backed by trainer GPUs via an injected worker group; initialized and then immediately put to sleep to free GPU memory for training
- **Phase 2, fixed standalone replicas** (ranks `N_e..N_e+N_f-1`): on dedicated rollout GPUs
- Runtime `add_replica(resource_id)` / `remove_replica(resource_id)` with atomic LB operations

#### 3. Trainer-Side Validation (`use_trainer_do_validate=True`)

Previously broken (guarded by an assertion) in fully-async mode. Now fully functional via a three-phase validation cycle:

| Phase | Action |
|------------------------|-----------------------------------------------------------------------------------------------|
| **1. TRAIN → ROLLOUT** | Sync weights → abort all replicas → activate elastic replicas in LB → resume generation |
| **2. Validate** | Execute validation via RPC to the rollouter |
| **3. ROLLOUT → TRAIN** | Abort all replicas → deactivate elastic replicas → sleep elastic GPUs → resume fixed replicas |

Uses a dedicated `hybrid_checkpoint_manager` (naive backend) for the elastic replica pool, separate from the existing `checkpoint_manager` for fixed rollout replicas.

#### 4. KV-Cache-Only Weight Sync Optimization

vLLM's `sleep(level=1)` mode allows restoring weights first and the KV cache separately. During parameter synchronization we now call `release_kv_cache` → NCCL sync → `resume_kv_cache` instead of the heavier full `sleep` → sync → `wake_up` cycle, reducing memory pressure.

#### 5. Abort State Tracking

Both the vLLM and SGLang servers now track an `_is_aborted` flag:

- `abort_all_requests()` sets it, so subsequent `generate()` calls return immediately with `stop_reason="aborted"`
- `resume_generation()` clears it
- This prevents post-abort processing errors (e.g., `IndexError` on empty outputs)

### Design Goals

1. **Single Responsibility**: the manager owns the lifecycle; the LB handles routing plus the (merged) handle mapping; the client only sends requests
2. **Elastic Convergence**: replica add/remove operates on a single LB Ray actor (internal handle registry); no client/worker notification
3. **Elastic Resources**: `FullyAsyncLLMServerManager` implements elastic resource registration with two-phase init
4. **Trainer-Side Validation**: `use_trainer_do_validate=True` is supported via elastic hybrid replicas on trainer GPUs

Minimal sketches of changes 1, 3, and 5 follow this list.
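To make change 1 concrete, below is a minimal, non-authoritative sketch of the single-actor pattern, assuming the method names exercised by the new tests (`acquire_server`, `release_server`, `add_servers`, `remove_servers`); the sticky-session and least-loaded details are simplified placeholders, not the actual verl implementation.

```python
import ray


@ray.remote
class RequestLoadBalancerSketch:
    """Minimal sketch: routing pool, sticky sessions, and actor handles live in one actor."""

    def __init__(self, servers: dict):
        self._servers = dict(servers)                          # server_id -> actor handle (registry)
        self._inflight_requests = {sid: 0 for sid in servers}  # server_id -> inflight count (routing pool)
        self._sticky = {}                                      # request_id -> server_id

    def acquire_server(self, request_id: str):
        sid = self._sticky.get(request_id)
        if sid not in self._inflight_requests:  # new request, or sticky target was removed
            if not self._inflight_requests:
                raise RuntimeError("No available servers")
            # Least-loaded routing for new (or invalidated) requests.
            sid = min(self._inflight_requests, key=self._inflight_requests.get)
            self._sticky[request_id] = sid
        self._inflight_requests[sid] += 1
        # A single RPC returns both the id and the handle -- no client-side cache to go stale.
        return sid, self._servers[sid]

    def release_server(self, server_id: str):
        # Hybrid-safe: silently ignore unknown or idle servers.
        if self._inflight_requests.get(server_id, 0) > 0:
            self._inflight_requests[server_id] -= 1

    def add_servers(self, servers: dict):
        self._servers.update(servers)
        for sid in servers:
            self._inflight_requests.setdefault(sid, 0)

    def remove_servers(self, server_ids: list):
        # Purge the routing pool and the handle registry in the same call.
        for sid in server_ids:
            self._servers.pop(sid, None)
            self._inflight_requests.pop(sid, None)
```

A real implementation would also carry the bookkeeping behind `get_status()` / `get_inflight_count()` and sticky-entry eviction; the sketch only shows why one atomic RPC can return both the server id and its handle, and why add/remove touches a single actor.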
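The three-phase cycle in change 3 can be read as the following pseudo-flow. Every name here (`hybrid_checkpoint_manager`, `lb`, `rollouter`, the per-replica methods) is a placeholder standing in for whatever the fully-async trainer actually wires up; only the ordering of operations follows the table above.

```python
import asyncio


async def trainer_side_validation(hybrid_checkpoint_manager, lb, rollouter, elastic, fixed):
    """Illustrative outline of the TRAIN -> ROLLOUT -> TRAIN validation cycle (not the real API)."""
    # Phase 1: TRAIN -> ROLLOUT
    await hybrid_checkpoint_manager.update_weights()  # sync trainer weights into elastic replicas
    await asyncio.gather(*[r.abort_all_requests() for r in elastic + fixed])
    await lb.add_servers.remote(servers={r.server_id: r.handle for r in elastic})
    await asyncio.gather(*[r.resume_generation() for r in elastic + fixed])

    # Phase 2: validation runs against the enlarged rollout pool
    metrics = await rollouter.validate.remote()

    # Phase 3: ROLLOUT -> TRAIN
    await asyncio.gather(*[r.abort_all_requests() for r in elastic + fixed])
    await lb.remove_servers.remote(server_ids=[r.server_id for r in elastic])
    await asyncio.gather(*[r.sleep() for r in elastic])  # free trainer GPUs for training again
    await asyncio.gather(*[r.resume_generation() for r in fixed])
    return metrics
```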
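The abort-state tracking in change 5 reduces to a small guard around the generate path. A minimal sketch; the class and the `engine` calls below are illustrative stand-ins, not the actual vLLM/SGLang server code.

```python
class AbortAwareServer:
    """Illustrative wrapper showing the _is_aborted guard (not the real server class)."""

    def __init__(self, engine):
        self.engine = engine           # stand-in for the underlying inference engine
        self._is_aborted = False

    async def abort_all_requests(self):
        self._is_aborted = True
        await self.engine.abort_all()  # cancel in-flight generation (placeholder call)

    async def resume_generation(self):
        self._is_aborted = False       # accept new work again

    async def generate(self, prompt_ids, sampling_params, request_id):
        if self._is_aborted:
            # Bail out before touching (possibly empty) engine outputs,
            # avoiding post-abort errors such as IndexError.
            return {"token_ids": [], "stop_reason": "aborted"}
        return await self.engine.generate(prompt_ids, sampling_params, request_id)
```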
### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `vllm_omni`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/verl-project/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/verl-project/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update [the documentation](https://github.com/verl-project/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/verl-project/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.
1 parent 3780ef4 commit d324b01

17 files changed

Lines changed: 920 additions & 512 deletions

tests/checkpoint_engine/test_special_server_adapter.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -21,12 +21,13 @@
 
 from tests.checkpoint_engine.test_utils import create_trainer_worker_group
 from verl.checkpoint_engine import CheckpointEngineManager
+from verl.experimental.fully_async_policy.fully_async_rollouter import FullyAsyncLLMServerClient
 from verl.single_controller.ray import (
     RayResourcePool,
 )
 from verl.utils.config import omega_conf_to_dataclass
 from verl.workers.config import CheckpointEngineConfig, HFModelConfig
-from verl.workers.rollout.llm_server import FullyLLMServerClient, LLMServerClient, LLMServerManager
+from verl.workers.rollout.llm_server import LLMServerClient, LLMServerManager
 
 
 @pytest.fixture
@@ -123,7 +124,7 @@ async def _run_server_manager_without_resume(
 async def _run_server_manager_with_resume(
     initial_steps: int,
     train_steps: int,
-    server_manager: FullyLLMServerClient,
+    server_manager: FullyAsyncLLMServerClient,
     checkpoint_manager: CheckpointEngineManager,
     prompts: list[list[dict]],
     tokenizer: PreTrainedTokenizer,
@@ -231,7 +232,7 @@ async def test_server_adapter(init_config):
     await _run_server_manager_with_resume(
         initial_steps=4,
         train_steps=3,
-        server_manager=llm_server_manager.get_client(fully_async=True),
+        server_manager=llm_server_manager.get_client(client_cls=FullyAsyncLLMServerClient),
         checkpoint_manager=checkpoint_manager,
         prompts=prompts,
         tokenizer=model_config.tokenizer,
```

tests/checkpoint_engine/test_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -40,7 +40,7 @@ def __init__(self, config: TrainingWorkerConfig, checkpoint_engine_config: Check
         self.checkpoint_engine = CheckpointEngineRegistry.new(backend, bucket_size=bucket_size, **engine_kwargs)
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
-    async def update_weights(self, global_steps: int = None):
+    async def update_weights(self, global_steps: int = None, mode: str = "auto"):
         per_tensor_param, _ = self.engine.get_per_tensor_param()
         await self.checkpoint_engine.send_weights(per_tensor_param)
 
```

tests/experimental/agent_loop/test_basic_agent_loop.py

Lines changed: 105 additions & 20 deletions
```diff
@@ -422,51 +422,136 @@ class TestLoadBalancerRouting:
 
     def test_distributes_across_servers(self, ray_for_lb):
         lb = GlobalRequestLoadBalancer.remote(servers={"s0": None, "s1": None, "s2": None})
-        servers = [ray.get(lb.acquire_server.remote(request_id=f"r{i}")) for i in range(3)]
+        servers = [ray.get(lb.acquire_server.remote(request_id=f"r{i}"))[0] for i in range(3)]
         assert sorted(servers) == ["s0", "s1", "s2"]
 
     def test_new_requests_route_to_least_loaded(self, ray_for_lb):
         lb = GlobalRequestLoadBalancer.remote(servers={"s0": None, "s1": None, "s2": None})
         # Load s0 with 3 inflight requests
-        ray.get(lb.acquire_server.remote(request_id="a"))  # -> s0
-        ray.get(lb.acquire_server.remote(request_id="a"))  # sticky -> s0
-        ray.get(lb.acquire_server.remote(request_id="a"))  # sticky -> s0
+        ray.get(lb.acquire_server.remote(request_id="a"))[0]  # -> s0
+        ray.get(lb.acquire_server.remote(request_id="a"))[0]  # sticky -> s0
+        ray.get(lb.acquire_server.remote(request_id="a"))[0]  # sticky -> s0
         # Load s1 with 1 inflight request
-        ray.get(lb.acquire_server.remote(request_id="b"))  # -> s1
+        ray.get(lb.acquire_server.remote(request_id="b"))[0]  # -> s1
         # s2 has 0 inflight, so next new request must go to s2
-        s_new = ray.get(lb.acquire_server.remote(request_id="d"))
+        s_new = ray.get(lb.acquire_server.remote(request_id="d"))[0]
         assert s_new == "s2"
 
     def test_release_rebalances(self, ray_for_lb):
         lb = GlobalRequestLoadBalancer.remote(servers={"s0": None, "s1": None})
-        s0 = ray.get(lb.acquire_server.remote(request_id="r0"))
-        s1 = ray.get(lb.acquire_server.remote(request_id="r1"))
+        s0 = ray.get(lb.acquire_server.remote(request_id="r0"))[0]
+        s1 = ray.get(lb.acquire_server.remote(request_id="r1"))[0]
         assert s0 != s1
         ray.get(lb.release_server.remote(server_id=s0))
         ray.get(lb.release_server.remote(server_id=s1))
-        s2 = ray.get(lb.acquire_server.remote(request_id="r2"))
-        s3 = ray.get(lb.acquire_server.remote(request_id="r3"))
+        s2 = ray.get(lb.acquire_server.remote(request_id="r2"))[0]
+        s3 = ray.get(lb.acquire_server.remote(request_id="r3"))[0]
         assert s2 != s3
 
-    def test_release_invalid_server_raises(self, ray_for_lb):
+    def test_release_invalid_server_silently_ignored(self, ray_for_lb):
+        """Releasing a nonexistent server is silently ignored (hybrid-safe)."""
         lb = GlobalRequestLoadBalancer.remote(servers={"s0": None, "s1": None})
-        with pytest.raises(ray.exceptions.RayTaskError, match="Invalid server_id") as excinfo:
-            ray.get(lb.release_server.remote(server_id="nonexistent"))
-        assert "Invalid server_id" in str(excinfo.value)
+        # Should not raise
+        ray.get(lb.release_server.remote(server_id="nonexistent"))
 
-    def test_release_without_inflight_raises(self, ray_for_lb):
+    def test_release_without_inflight_silently_ignored(self, ray_for_lb):
+        """Releasing a server with no inflight requests is silently ignored (hybrid-safe)."""
         lb = GlobalRequestLoadBalancer.remote(servers={"s0": None, "s1": None})
-        with pytest.raises(ray.exceptions.RayTaskError, match="no inflight") as excinfo:
-            ray.get(lb.release_server.remote(server_id="s1"))
-        assert "no inflight" in str(excinfo.value)
+        # Should not raise even though s1 has 0 inflight
+        ray.get(lb.release_server.remote(server_id="s1"))
 
 
 class TestLoadBalancerStickySession:
     """Request-level sticky session."""
 
     def test_same_request_id_same_server(self, ray_for_lb):
         lb = GlobalRequestLoadBalancer.remote(servers={"s0": None, "s1": None, "s2": None, "s3": None})
-        s0 = ray.get(lb.acquire_server.remote(request_id="conv-abc"))
+        s0 = ray.get(lb.acquire_server.remote(request_id="conv-abc"))[0]
         ray.get(lb.release_server.remote(server_id=s0))
-        s1 = ray.get(lb.acquire_server.remote(request_id="conv-abc"))
+        s1 = ray.get(lb.acquire_server.remote(request_id="conv-abc"))[0]
         assert s0 == s1
+
+
+class TestLoadBalancerHybrid:
+    """Dynamic server add/remove for hybrid scaling."""
+
+    def test_add_server(self, ray_for_lb):
+        lb = GlobalRequestLoadBalancer.remote(servers={"s0": None, "s1": None})
+        ray.get(lb.add_servers.remote(servers={"s2": None}))
+        status = ray.get(lb.get_status.remote())
+        assert "s2" in status["servers"]
+        assert status["servers"]["s2"] == 0
+
+    def test_remove_server_purges_handle(self, ray_for_lb):
+        lb = GlobalRequestLoadBalancer.remote(servers={"s0": None, "s1": None})
+        ray.get(lb.remove_servers.remote(server_ids=["s1"]))
+        # remove_server now purges from both _inflight_requests and _servers
+        status = ray.get(lb.get_status.remote())
+        assert "s1" not in status["servers"]
+        assert "s1" not in status["registered_handles"]
+        # New requests should only go to s0
+        s = ray.get(lb.acquire_server.remote(request_id="r1"))[0]
+        assert s == "s0"
+
+    def test_removed_server_invalidates_sticky_session(self, ray_for_lb):
+        """When a sticky session points to a removed server, cache is invalidated."""
+        lb = GlobalRequestLoadBalancer.remote(servers={"s0": None, "s1": None})
+        # Occupy s0 so that the sticky request is assigned to s1
+        ray.get(lb.acquire_server.remote(request_id="occupy-s0"))[0]  # -> s0
+        # Pin request to s1 (least-loaded now)
+        s1 = ray.get(lb.acquire_server.remote(request_id="sticky-req"))[0]
+        assert s1 == "s1"
+        ray.get(lb.release_server.remote(server_id=s1))
+        # Remove s1
+        ray.get(lb.remove_servers.remote(server_ids=["s1"]))
+        # Sticky session should be invalidated and reroute to s0
+        s_new = ray.get(lb.acquire_server.remote(request_id="sticky-req"))[0]
+        assert s_new == "s0"
+
+    def test_remove_server_also_purges_registry(self, ray_for_lb):
+        """remove_servers atomically purges from both LB pool and handle registry."""
+        lb = GlobalRequestLoadBalancer.remote(servers={"s0": None, "s1": None})
+        ray.get(lb.remove_servers.remote(server_ids=["s1"]))
+        status = ray.get(lb.get_status.remote())
+        # Both _inflight_requests and _servers are cleaned up (no separate cleanup step needed)
+        assert "s1" not in status["servers"]
+        assert "s1" not in status["registered_handles"]
+
+    def test_get_all_servers_excludes_removed(self, ray_for_lb):
+        lb = GlobalRequestLoadBalancer.remote(servers={"s0": None, "s1": None, "s2": None})
+        ray.get(lb.remove_servers.remote(server_ids=["s1"]))
+        all_servers = ray.get(lb.get_all_servers.remote())
+        assert "s0" in all_servers
+        assert "s2" in all_servers
+        assert "s1" not in all_servers
+
+    def test_no_available_servers_raises(self, ray_for_lb):
+        lb = GlobalRequestLoadBalancer.remote(servers={"s0": None, "s1": None})
+        ray.get(lb.remove_servers.remote(server_ids=["s0", "s1"]))
+        with pytest.raises(ray.exceptions.RayTaskError, match="No available servers"):
+            ray.get(lb.acquire_server.remote(request_id="r1"))
+
+    def test_add_server_readds_previously_removed(self, ray_for_lb):
+        """Re-adding a previously removed server makes it routable again."""
+        lb = GlobalRequestLoadBalancer.remote(servers={"s0": None, "s1": None})
+        ray.get(lb.remove_servers.remote(server_ids=["s1"]))
+        # s1 is removed, only s0 is available
+        assert ray.get(lb.acquire_server.remote(request_id="r1"))[0] == "s0"
+        # Re-add s1
+        ray.get(lb.add_servers.remote(servers={"s1": None}))
+        # Now both s0 and s1 should be available
+        s = ray.get(lb.acquire_server.remote(request_id="r2"))[0]
+        assert s in ("s0", "s1")
+
+    def test_get_inflight_count(self, ray_for_lb):
+        lb = GlobalRequestLoadBalancer.remote(servers={"s0": None, "s1": None})
+        assert ray.get(lb.get_inflight_count.remote(server_id="s0")) == 0
+        ray.get(lb.acquire_server.remote(request_id="r1"))[0]  # -> s0 (least loaded)
+        assert ray.get(lb.get_inflight_count.remote(server_id="s0")) == 1
+
+    def test_get_status_reports_active_correctly(self, ray_for_lb):
+        lb = GlobalRequestLoadBalancer.remote(servers={"s0": None, "s1": None, "s2": None})
+        ray.get(lb.remove_servers.remote(server_ids=["s1"]))
+        status = ray.get(lb.get_status.remote())
+        assert status["active_servers"] == 2  # s0 and s2
+        assert status["total_inflight"] == 0
```

verl/checkpoint_engine/base.py

Lines changed: 39 additions & 8 deletions
```diff
@@ -416,6 +416,34 @@ async def wake_up_replicas(self):
         """Resume all rollout replicas: recover kv_cache and weights device memory."""
         await asyncio.gather(*[r.wake_up() for r in self.replicas])
 
+    @auto_await
+    async def abort_replicas(self):
+        """Abort all in-flight requests on every replica."""
+        await asyncio.gather(*[r.abort_all_requests() for r in self.replicas])
+
+    @auto_await
+    async def resume_generation_replicas(self):
+        """Resume generation on all replicas after abort_all_requests."""
+        await asyncio.gather(*[r.resume_generation() for r in self.replicas])
+
+    @auto_await
+    async def release_kv_cache_replicas(self):
+        """Release kv_cache of all rollout replicas before NCCL weight sync.
+
+        Unlike sleep_replicas(), this only frees the kv_cache and leaves model
+        weights untouched, so the NCCL transfer can write directly into the
+        existing weight buffers. Call resume_kv_cache_replicas() after sync.
+        """
+        await asyncio.gather(*[r.release_kv_cache() for r in self.replicas])
+
+    @auto_await
+    async def resume_kv_cache_replicas(self):
+        """Restore kv_cache of all rollout replicas after NCCL weight sync.
+
+        Counterpart to release_kv_cache_replicas().
+        """
+        await asyncio.gather(*[r.resume_kv_cache() for r in self.replicas])
+
     @auto_await
     async def update_weights(self, global_steps: int = None):
         """Update weights from trainer to rollout replicas.
@@ -426,11 +454,11 @@ async def update_weights(self, global_steps: int = None):
 
         # 0. update weights for sync training with colocated trainer and rollout
         if self.backend == "naive":
-            ray.get(self.trainer.update_weights(global_steps=global_steps))
+            ray.get(self.trainer.update_weights(global_steps=global_steps, mode=self.backend))
             return
 
         # 1. abort and save all unfinished requests for partial rollout
-        await asyncio.gather(*[r.abort_all_requests() for r in self.replicas])
+        await self.abort_replicas()
 
         # 2. create a temporay worker group for all replicas
         workers = []
@@ -439,26 +467,29 @@ async def update_weights(self, global_steps: int = None):
         rollout = RayWorkerGroup(worker_handles=workers, ray_cls_with_init=RayClassWithInitArgs(cls=_worker_cls))
         trainer = self.trainer
 
-        # 3. sleep replicas to free kv_cache before weight sync (if free_cache_engine is enabled)
-        await self.sleep_replicas()
+        # 3. release kv_cache before weight sync (weights stay in place)
+        await self.release_kv_cache_replicas()
 
         # 4. build process group
         self.build_process_group(rollout)
 
         # 5. update weights of all workers
-        ray.get(trainer.update_weights(global_steps=global_steps) + rollout.update_weights(global_steps=global_steps))
+        ray.get(
+            trainer.update_weights(global_steps=global_steps, mode=self.backend)
+            + rollout.update_weights(global_steps=global_steps)
+        )
 
         # 6. finalize all workers
         ray.get(
             trainer.execute_checkpoint_engine(["finalize"] * trainer.world_size)
             + rollout.execute_checkpoint_engine(["finalize"] * rollout.world_size)
         )
 
-        # 7. resume replicas to recover kv_cache (for free_cache_engine scenarios)
-        await self.wake_up_replicas()
+        # 7. restore kv_cache after weight sync
+        await self.resume_kv_cache_replicas()
 
         # 8. resume all unfinished requests for partial rollout
-        await asyncio.gather(*[r.resume_generation() for r in self.replicas])
+        await self.resume_generation_replicas()
 
 
     async def split_weight_chunks(
```

verl/experimental/fully_async_policy/detach_utils.py

Lines changed: 1 addition & 10 deletions
```diff
@@ -15,7 +15,7 @@
 import time
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any
 
 import numpy as np
 import torch
@@ -39,15 +39,6 @@ class RolloutSample:
     rollout_status: dict[str, Any]
 
 
-@dataclass
-class ValidateMetrics:
-    """Metrics for validation"""
-
-    timing_raw: dict[str, Any]
-    metrics: Optional[dict[str, Any]] = None
-    val_generations: Optional[list[tuple]] = None
-
-
 def prepare_single_generation_data(batch_dict, config) -> DataProto:
     """
     Similar to the logic of ray_trainer._prepare_generate_batch, but for a single sample.
```
