|
35 | 35 | from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig |
36 | 36 | from xtuner.v1.rl.agent_loop_manager import ( |
37 | 37 | AgentLoopManagerConfig, |
38 | | - AsyncProduceStrategyConfig, |
| 38 | + DisaggAsyncProduceStrategyConfig, |
| 39 | + DisaggAgentLoopManagerConfig, |
| 40 | + DisaggTaskSpecConfig, |
39 | 41 | SamplerConfig, |
40 | | - SyncProduceStrategyConfig, |
41 | 42 | TaskSpecConfig, |
42 | 43 | ) |
43 | 44 | from xtuner.v1.rl.evaluator import EvaluatorConfig |
44 | 45 | from xtuner.v1.rl.judger import DapoMathJudgerConfig |
45 | 46 | from xtuner.v1.rl.loss import GRPOLossConfig |
46 | | -from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig |
| 47 | +from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig |
47 | 48 | from xtuner.v1.rl.rollout.worker import RolloutConfig |
48 | 49 | from xtuner.v1.rl.trainer import WorkerConfig |
49 | 50 | from xtuner.v1.rl.utils import AcceleratorResourcesConfig, get_eos_token |
|
221 | 222 | ), |
222 | 223 | ) |
223 | 224 |
|
224 | | -if over_sample_threshold > 0 or partial_rollout: |
225 | | - produce_strategy_config = AsyncProduceStrategyConfig( |
226 | | - over_sample_threshold=over_sample_threshold, |
227 | | - enable_partial_rollout=partial_rollout, |
228 | | - tail_batch_trigger_size=tail_batch_trigger_size, |
229 | | - max_staleness=max_staleness, |
230 | | - ) |
231 | | -else: |
232 | | - produce_strategy_config = SyncProduceStrategyConfig() |
233 | | - |
234 | | -agent_loop_manager_cfg = AgentLoopManagerConfig( |
| 225 | +# The non-colocated (disaggregated) background producer uses its own dedicated Disagg* configs and does not reuse the colocated AsyncProduceStrategyConfig. |
| 226 | +produce_strategy_config = DisaggAsyncProduceStrategyConfig( |
| 227 | + over_sample_threshold=over_sample_threshold, |
| 228 | + enable_partial_rollout=partial_rollout, |
| 229 | + tail_batch_trigger_size=tail_batch_trigger_size, |
| 230 | + max_staleness=max_staleness, |
| 231 | +) |
| 232 | + |
| 233 | +agent_loop_manager_cfg = DisaggAgentLoopManagerConfig( |
235 | 234 | tasks=[ |
236 | | - TaskSpecConfig( |
| 235 | + DisaggTaskSpecConfig( |
237 | 236 | task_name="train_task:dapo_math", |
238 | 237 | weight=dapo_task_weight, |
239 | 238 | agent_loop_config=dapo_train_agent_loop_config, |
240 | 239 | judger_config=judger_config, |
241 | 240 | produce_strategy_config=produce_strategy_config, |
242 | 241 | sampler_config=dapo_train_sampler_config, |
243 | 242 | ), |
244 | | - TaskSpecConfig( |
| 243 | + DisaggTaskSpecConfig( |
245 | 244 | task_name="train_task:gsm8k", |
246 | 245 | weight=gsm8k_task_weight, |
247 | 246 | agent_loop_config=gsm8k_train_agent_loop_config, |
@@ -335,7 +334,7 @@ def compute_metric(samples): |
335 | 334 | train_worker_cfg=train_worker_cfg, |
336 | 335 | rollout_config=rollout_config, |
337 | 336 | tokenizer_path=model_path, |
338 | | - replay_buffer_config=SyncReplayBufferConfig(), |
| 337 | + replay_buffer_config=AsyncReplayBufferConfig(), |
339 | 338 | agent_loop_manager_cfg=agent_loop_manager_cfg, |
340 | 339 | eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg, |
341 | 340 | evaluator_config=evaluator_config, |
|
0 commit comments