Skip to content

Commit 083a8ad

Browse files
authored
Refine rollout worker health check and recovery lifecycle (#1877)
* Refactor rollout server launch specs
* Refactor rollout health manager
* Restore rollout worker session URLs
* fix
* fix
* fix
* fix comments
* Delay rollout health checks after resume and log LMDeploy health errors
* fix ci
* delete rank2info and worker_server_urls_map
* Require fail-fast health checks before rollout offload
* Align rollout generate concurrency with request entrypoint topology
* Refactor colocate rollout recovery with fail-fast shutdown and pre-sync restart
* simplify code
* fix ci
* fix ci and comments
* Refactor rollout worker state into registry
1 parent b2c8c6c commit 083a8ad

12 files changed

Lines changed: 1780 additions & 746 deletions

tests/rl/test_rl_colocate_trainer.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def __init__(self, uid: int):
5252
self.extra_fields = {}
5353
self.response_model_steps = []
5454

55-
5655
class _FakeSampler:
5756
def __init__(self):
5857
self._next_id = 0
@@ -148,13 +147,18 @@ def _make_trainer(self, agent_loop_manager, *, total_train_steps: int = 1, sync_
148147
)
149148

150149
trainer.rollout_controller = SimpleNamespace(
151-
ensure_workers_healthy_before_training=SimpleNamespace(
152-
remote=MagicMock(return_value="rollout_ready_for_training")
150+
check_and_shutdown_inactive_workers=SimpleNamespace(
151+
remote=MagicMock(return_value="rollout_inactive_workers_shutdown")
153152
),
154153
offload=SimpleNamespace(remote=MagicMock(return_value="rollout_offloaded")),
154+
restart_inactive_workers=SimpleNamespace(remote=MagicMock(return_value="rollout_restarted")),
155+
onload_weights=SimpleNamespace(remote=MagicMock(return_value="weights_loaded")),
156+
onload_kvcache=SimpleNamespace(remote=MagicMock(return_value="kvcache_loaded")),
155157
)
156158
trainer.train_controller = SimpleNamespace(
157159
onload=MagicMock(return_value="train_onloaded"),
160+
offload=MagicMock(return_value="train_offloaded"),
161+
update_weights=MagicMock(return_value="weights_updated"),
158162
fit=MagicMock(
159163
return_value=[
160164
{
@@ -220,15 +224,37 @@ async def _produce_empty(batch_size, train_step, **kwargs):
220224
trainer.train_controller.fit.assert_not_called()
221225
self.assertEqual(trainer._cur_step, 0)
222226

227+
def test_fit_does_not_onload_train_when_rollout_training_barrier_fails(self):
228+
# Verify that colocated training must pass the rollout phase-switch barrier before entering training;
229+
# if the barrier fails, the train side must not be onloaded.
230+
async def _produce_batch(batch_size, train_step, *, model_step):
231+
return ProduceBatchResult(rollout_states=[[_FakeRolloutState(train_step)]])
232+
233+
trainer = self._make_trainer(SimpleNamespace(produce_batch=_produce_batch))
234+
trainer.rollout_controller.check_and_shutdown_inactive_workers.remote.side_effect = RuntimeError(
235+
"inactive rollout workers after recovery"
236+
)
237+
238+
with (
239+
patch("xtuner.v1.train.rl_trainer.asyncio_run", side_effect=asyncio.run),
240+
patch("xtuner.v1.train.rl_trainer.ray.get", side_effect=lambda obj, timeout=None: obj),
241+
):
242+
with self.assertRaisesRegex(RuntimeError, "inactive rollout workers"):
243+
trainer.fit()
244+
245+
trainer.rollout_controller.check_and_shutdown_inactive_workers.remote.assert_called_once_with()
246+
trainer.rollout_controller.offload.remote.assert_not_called()
247+
trainer.train_controller.onload.assert_not_called()
248+
trainer.train_controller.fit.assert_not_called()
249+
self.assertEqual(trainer._cur_step, 0)
250+
223251
def test_fit_uses_sync_interval_and_passes_rollout_model_step(self):
224252
# Verify that rollout sees the model_step as advanced according to the sync interval.
225253
produce_calls = []
226254

227255
async def _produce_batch(batch_size, train_step, *, model_step):
228256
produce_calls.append((batch_size, train_step, model_step))
229-
return ProduceBatchResult(
230-
rollout_states=[[SimpleNamespace(group_id=train_step, rollout_id=train_step)]]
231-
)
257+
return ProduceBatchResult(rollout_states=[[_FakeRolloutState(train_step)]])
232258

233259
trainer = self._make_trainer(
234260
SimpleNamespace(produce_batch=_produce_batch),

tests/rl/test_rl_disaggregated_trainer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,10 @@ def _make_trainer(self, agent_loop_manager):
146146
update_weights=MagicMock(return_value="update"),
147147
)
148148
trainer.rollout_controller = SimpleNamespace(
149-
recover_failed_workers=SimpleNamespace(remote=MagicMock(return_value="recover")),
149+
check_and_shutdown_inactive_workers=SimpleNamespace(
150+
remote=MagicMock(return_value="rollout_inactive_workers_shutdown")
151+
),
152+
restart_inactive_workers=SimpleNamespace(remote=MagicMock(return_value="rollout_restarted")),
150153
pause_generation=SimpleNamespace(remote=MagicMock(return_value="pause")),
151154
continue_generation=SimpleNamespace(remote=MagicMock(return_value="continue")),
152155
onload_weights=SimpleNamespace(remote=MagicMock(return_value="onload_weights")),

tests/rl/test_rl_trainer_checkpoint.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ def __init__(self):
9090
self.pause_generation = _RemoteMethod(async_result=True)
9191
self.continue_generation = _RemoteMethod(async_result=True)
9292
self.offload = _RemoteMethod(return_value="rollout_offloaded")
93-
self.ensure_workers_healthy_before_training = _RemoteMethod(return_value="rollout_ready_for_training")
94-
self.recover_failed_workers = _RemoteMethod(return_value="rollout_recovered")
93+
self.check_and_shutdown_inactive_workers = _RemoteMethod(return_value="rollout_inactive_workers_shutdown")
94+
self.restart_inactive_workers = _RemoteMethod(return_value="rollout_restarted")
9595
self.onload_weights = _RemoteMethod(return_value="weights_loaded")
9696
self.onload_kvcache = _RemoteMethod(return_value="kvcache_loaded")
9797
self.get_rollout_metadata = _RemoteMethod(return_value={"server_url_dict": {}})
@@ -204,6 +204,7 @@ def build_rollout_controller(rollout_cfg, placement_group):
204204
return controller
205205

206206
with (
207+
patch("ray.get", side_effect=lambda obj, timeout=None: obj),
207208
patch("xtuner.v1.rl.utils.ray_accelerator_worker.ray.is_initialized", return_value=True),
208209
patch(
209210
"xtuner.v1.rl.utils.ray_accelerator_worker.ray.available_resources",
@@ -217,6 +218,12 @@ def build_rollout_controller(rollout_cfg, placement_group):
217218
patch("xtuner.v1.train.rl_trainer.BaseRLTrainer._release_trace_store", return_value=None),
218219
patch.object(WorkerConfig, "build", autospec=True, side_effect=build_train_controller),
219220
patch.object(RolloutConfig, "build", autospec=True, side_effect=build_rollout_controller),
221+
patch.object(
222+
RolloutConfig,
223+
"get_controller_generate_concurrency",
224+
autospec=True,
225+
side_effect=lambda rollout_cfg, placement_group: rollout_cfg.generate_concurrency_per_instance,
226+
),
220227
):
221228
yield runtime
222229

@@ -321,6 +328,7 @@ def _build_colocate_config(
321328
auto_resume=auto_resume,
322329
checkpoint_interval=1,
323330
checkpoint_maxkeep=None,
331+
checkpoint_no_save_replay_buffer=True,
324332
hf_interval=-1,
325333
seed=42,
326334
exp_tracker="jsonl",
@@ -361,6 +369,7 @@ def _build_disaggregated_config(
361369
auto_resume=auto_resume,
362370
checkpoint_interval=1,
363371
checkpoint_maxkeep=None,
372+
checkpoint_no_save_replay_buffer=True,
364373
hf_interval=-1,
365374
seed=42,
366375
exp_tracker="jsonl",

0 commit comments

Comments
 (0)