Skip to content

Commit c2231af

Browse files
authored
[Refactor] Separate colocate and disaggregated produce flow (#1884)
1 parent a4bea0d commit c2231af

20 files changed

Lines changed: 4481 additions & 1642 deletions

docs/design/sep_code.md

Lines changed: 492 additions & 0 deletions
Large diffs are not rendered by default.

docs/design/sep_code_demo.py

Lines changed: 1472 additions & 0 deletions
Large diffs are not rendered by default.

examples/v1/config/rl_disagg_multi.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,16 @@
3535
from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig
3636
from xtuner.v1.rl.agent_loop_manager import (
3737
AgentLoopManagerConfig,
38-
AsyncProduceStrategyConfig,
38+
DisaggAsyncProduceStrategyConfig,
39+
DisaggAgentLoopManagerConfig,
40+
DisaggTaskSpecConfig,
3941
SamplerConfig,
40-
SyncProduceStrategyConfig,
4142
TaskSpecConfig,
4243
)
4344
from xtuner.v1.rl.evaluator import EvaluatorConfig
4445
from xtuner.v1.rl.judger import DapoMathJudgerConfig
4546
from xtuner.v1.rl.loss import GRPOLossConfig
46-
from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig
47+
from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig
4748
from xtuner.v1.rl.rollout.worker import RolloutConfig
4849
from xtuner.v1.rl.trainer import WorkerConfig
4950
from xtuner.v1.rl.utils import AcceleratorResourcesConfig, get_eos_token
@@ -221,27 +222,25 @@
221222
),
222223
)
223224

224-
if over_sample_threshold > 0 or partial_rollout:
225-
produce_strategy_config = AsyncProduceStrategyConfig(
226-
over_sample_threshold=over_sample_threshold,
227-
enable_partial_rollout=partial_rollout,
228-
tail_batch_trigger_size=tail_batch_trigger_size,
229-
max_staleness=max_staleness,
230-
)
231-
else:
232-
produce_strategy_config = SyncProduceStrategyConfig()
233-
234-
agent_loop_manager_cfg = AgentLoopManagerConfig(
225+
# 非共卡后台 producer 使用独立的 Disagg* config,不复用共卡 AsyncProduceStrategyConfig。
226+
produce_strategy_config = DisaggAsyncProduceStrategyConfig(
227+
over_sample_threshold=over_sample_threshold,
228+
enable_partial_rollout=partial_rollout,
229+
tail_batch_trigger_size=tail_batch_trigger_size,
230+
max_staleness=max_staleness,
231+
)
232+
233+
agent_loop_manager_cfg = DisaggAgentLoopManagerConfig(
235234
tasks=[
236-
TaskSpecConfig(
235+
DisaggTaskSpecConfig(
237236
task_name="train_task:dapo_math",
238237
weight=dapo_task_weight,
239238
agent_loop_config=dapo_train_agent_loop_config,
240239
judger_config=judger_config,
241240
produce_strategy_config=produce_strategy_config,
242241
sampler_config=dapo_train_sampler_config,
243242
),
244-
TaskSpecConfig(
243+
DisaggTaskSpecConfig(
245244
task_name="train_task:gsm8k",
246245
weight=gsm8k_task_weight,
247246
agent_loop_config=gsm8k_train_agent_loop_config,
@@ -335,7 +334,7 @@ def compute_metric(samples):
335334
train_worker_cfg=train_worker_cfg,
336335
rollout_config=rollout_config,
337336
tokenizer_path=model_path,
338-
replay_buffer_config=SyncReplayBufferConfig(),
337+
replay_buffer_config=AsyncReplayBufferConfig(),
339338
agent_loop_manager_cfg=agent_loop_manager_cfg,
340339
eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg,
341340
evaluator_config=evaluator_config,

examples/v1/config/rl_disagg_single.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,16 @@
4545
from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig
4646
from xtuner.v1.rl.agent_loop_manager import (
4747
AgentLoopManagerConfig,
48-
AsyncProduceStrategyConfig,
48+
DisaggAsyncProduceStrategyConfig,
49+
DisaggAgentLoopManagerConfig,
50+
DisaggTaskSpecConfig,
4951
SamplerConfig,
50-
SyncProduceStrategyConfig,
5152
TaskSpecConfig,
5253
)
5354
from xtuner.v1.rl.evaluator import EvaluatorConfig
5455
from xtuner.v1.rl.judger import GSM8KJudgerConfig
5556
from xtuner.v1.rl.loss import GRPOLossConfig
56-
from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig
57+
from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig
5758
from xtuner.v1.rl.rollout.worker import RolloutConfig
5859
from xtuner.v1.rl.trainer import WorkerConfig
5960
from xtuner.v1.rl.utils import AcceleratorResourcesConfig
@@ -193,17 +194,15 @@
193194
hf_checkpoint=model_path,
194195
sample_params=training_sample_params,
195196
)
196-
if over_sample_threshold > 0 or partial_rollout:
197-
produce_strategy_config = AsyncProduceStrategyConfig(
198-
over_sample_threshold=over_sample_threshold,
199-
enable_partial_rollout=partial_rollout,
200-
tail_batch_trigger_size=tail_batch_trigger_size,
201-
max_staleness=max_staleness,
202-
)
203-
else:
204-
produce_strategy_config = SyncProduceStrategyConfig()
205-
agent_loop_manager_cfg = AgentLoopManagerConfig(
206-
tasks=TaskSpecConfig(
197+
# 非共卡后台 producer 使用独立的 Disagg* config,不复用共卡 AsyncProduceStrategyConfig。
198+
produce_strategy_config = DisaggAsyncProduceStrategyConfig(
199+
over_sample_threshold=over_sample_threshold,
200+
enable_partial_rollout=partial_rollout,
201+
tail_batch_trigger_size=tail_batch_trigger_size,
202+
max_staleness=max_staleness,
203+
)
204+
agent_loop_manager_cfg = DisaggAgentLoopManagerConfig(
205+
tasks=DisaggTaskSpecConfig(
207206
task_name="train_task",
208207
agent_loop_config=agent_loop_config,
209208
judger_config=judger_config,
@@ -258,7 +257,7 @@
258257
train_worker_cfg=train_worker_cfg,
259258
rollout_config=rollout_config,
260259
tokenizer_path=model_path,
261-
replay_buffer_config=SyncReplayBufferConfig(),
260+
replay_buffer_config=AsyncReplayBufferConfig(),
262261
agent_loop_manager_cfg=agent_loop_manager_cfg,
263262
eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg,
264263
evaluator_config=evaluator_config,

0 commit comments

Comments
 (0)