@@ -178,6 +178,7 @@ class AgentInSandboxLoopConfig(AgentLoopConfig):
178178 max_concurrent_samples : int | None = None
179179 mode : Literal ["train" , "eval" ] = "train"
180180 requires_rollout_proxy : bool = True
181+ process_advantage_builder : str | None = None
181182
182183 def build_local (
183184 self , rollout_controller : RolloutController | None = None , judger : Judger | None = None , logger = None
@@ -190,6 +191,7 @@ def build_local(
190191 logger = logger ,
191192 max_concurrent_samples = self .max_concurrent_samples ,
192193 mode = self .mode ,
194+ process_advantage_builder = self .process_advantage_builder ,
193195 )
194196
195197
@@ -203,13 +205,17 @@ def __init__(
203205 logger = None ,
204206 max_concurrent_samples : int | None = None ,
205207 mode : Literal ["train" , "eval" ] = "train" ,
208+ process_advantage_builder : str | None = None ,
206209 ):
207210 if hf_checkpoint is None :
208211 raise ValueError ("hf_checkpoint must be provided for AgentInSandboxLoop." )
209212 super ().__init__ (rollout_ctl , sample_params , hf_checkpoint , judger , logger )
210213 self .max_concurrent_samples = max_concurrent_samples
211214 self ._sample_semaphore = asyncio .Semaphore (max_concurrent_samples ) if max_concurrent_samples else None
212215 self .mode = mode
216+ self .process_advantage_builder = (
217+ _import_from_path (process_advantage_builder ) if process_advantage_builder is not None else None
218+ )
213219
214220 async def generate_group (self , rollout_state : list [RolloutState ], ** kwargs ) -> list [RolloutState ]:
215221 async def generate_one (state : RolloutState ) -> list [RolloutState ]:
@@ -313,6 +319,16 @@ async def _build_rollout_states(self, rollout_state: RolloutState, item: AgentRo
313319 data = await trace_store .export_training_trace .remote (str (rollout_state .session_id ), prompt_text )
314320 segment_state .input_ids = data ["input_ids" ]
315321 segment_state .labels = data ["labels" ]
322+ segment_state .extra_fields ["agent_trace_segments" ] = data .get ("segments" , [])
323+ if self .process_advantage_builder is not None :
324+ segment_state .advantage_weight , process_adv_summary = self .process_advantage_builder (
325+ messages ,
326+ data ["labels" ],
327+ data .get ("segments" ),
328+ )
329+ segment_state .extra_fields ["process_adv" ] = process_adv_summary
330+ else :
331+ segment_state .advantage_weight = None
316332 # Agentic training consumes input_ids/labels directly. response_ids is
317333 # filled here only so rollout throughput logging can print rollout_tgs.
318334 segment_state .response_ids = [
@@ -341,6 +357,7 @@ def _fill_eval_rollout_state(self, rollout_state: RolloutState, item: AgentRollo
341357 rollout_state .routed_experts = None
342358 rollout_state .response_mask = None
343359 rollout_state .response_model_steps = None
360+ rollout_state .advantage_weight = None
344361 rollout_state .extra_fields ["agent_status" ] = item .status .value
345362 selected_agent = _selected_agent (item )
346363 if selected_agent is not None :
0 commit comments