Skip to content

Commit 0767d9e

Browse files
committed
Add process advantage weighting for agent rollouts
1 parent e5cdbfb commit 0767d9e

10 files changed

Lines changed: 212 additions & 4 deletions

File tree

xtuner/v1/data_proto/rl_data.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,10 @@ class RolloutState(BaseModel):
120120

121121
input_ids: list[int] | None = None
122122
labels: list[int] | None = None
123+
# Per-token multiplier applied to positive advantages after outcome reward
124+
# advantage estimation. Coordinates match input_ids / labels; trainer uses
125+
# advantage_weight[1:] to align with shifted_labels.
126+
advantage_weight: list[float] | None = None
123127

124128
# --- Judger 输出 ---
125129
reward: dict[str, Any] | None = None
@@ -248,6 +252,7 @@ def reset_rollout_response(rollout_state: RolloutState) -> RolloutState:
248252
rollout_state.finish_reason = None
249253
rollout_state.response_mask = []
250254
rollout_state.response_model_steps = []
255+
rollout_state.advantage_weight = None
251256
rollout_state.reward = None
252257
rollout_state.error_msg = None
253258
return rollout_state

xtuner/v1/rl/advantage/base.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,23 @@ def compute(self, rewards: torch.Tensor, group: list[Any]) -> torch.Tensor:
5858
"""
5959
...
6060

61+
def expand_to_token_advantages(
62+
self,
63+
*,
64+
base_advantage: float,
65+
rollout_state: Any,
66+
shifted_labels: list[int],
67+
shifted_advantage_weight: list[float] | None = None,
68+
) -> tuple[list[float], dict[str, Any]]:
69+
"""Expand a sample-level advantage to token-level advantages.
70+
71+
``compute`` intentionally stays sample/session-level. This hook lets
72+
downstream projects shape token credit after labels and optional
73+
per-token weights are known by the trainer.
74+
"""
75+
76+
del rollout_state, shifted_advantage_weight
77+
return [0.0 if label == -100 else base_advantage for label in shifted_labels], {}
78+
6179
def __repr__(self) -> str:
6280
return f"{self.__class__.__name__}()"

xtuner/v1/rl/agent_loop/localhost_agent_loop/agent_in_localhost_loop.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class AgentInLocalhostLoopConfig(AgentLoopConfig):
8585
sample_timeout_s: float | None = None
8686
mode: Literal["train", "eval"] = "train"
8787
requires_rollout_proxy: bool = True
88+
process_advantage_builder: str | None = None
8889

8990
def build_local(
9091
self,
@@ -101,6 +102,7 @@ def build_local(
101102
max_concurrent_samples=self.max_concurrent_samples,
102103
sample_timeout_s=self.sample_timeout_s,
103104
mode=self.mode,
105+
process_advantage_builder=self.process_advantage_builder,
104106
)
105107

106108

@@ -117,6 +119,7 @@ def __init__(
117119
max_concurrent_samples: int | None = None,
118120
sample_timeout_s: float | None = None,
119121
mode: Literal["train", "eval"] = "train",
122+
process_advantage_builder: str | None = None,
120123
):
121124
if hf_checkpoint is None:
122125
raise ValueError("hf_checkpoint must be provided for AgentInLocalhostLoop.")
@@ -125,6 +128,9 @@ def __init__(
125128
self.sample_timeout_s = sample_timeout_s
126129
self._sample_semaphore = asyncio.Semaphore(max_concurrent_samples) if max_concurrent_samples else None
127130
self.mode = mode
131+
self.process_advantage_builder = (
132+
_import_from_path(process_advantage_builder) if process_advantage_builder is not None else None
133+
)
128134

129135
async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]:
130136
async def generate_one(state: RolloutState) -> RolloutState:
@@ -246,6 +252,16 @@ async def _fill_rollout_state(self, rollout_state: RolloutState, item: AgentRoll
246252

247253
rollout_state.input_ids = data["input_ids"]
248254
rollout_state.labels = data["labels"]
255+
rollout_state.extra_fields["agent_trace_segments"] = data.get("segments", [])
256+
if self.process_advantage_builder is not None:
257+
rollout_state.advantage_weight, process_adv_summary = self.process_advantage_builder(
258+
segment["messages"],
259+
data["labels"],
260+
data.get("segments"),
261+
)
262+
rollout_state.extra_fields["process_adv"] = process_adv_summary
263+
else:
264+
rollout_state.advantage_weight = None
249265
rollout_state.response_ids = [
250266
token_id for token_id, label in zip(data["input_ids"][1:], data["labels"][1:]) if label != -100
251267
]
@@ -267,6 +283,7 @@ def _fill_eval_rollout_state(self, rollout_state: RolloutState, item: AgentRollo
267283
rollout_state.routed_experts = None
268284
rollout_state.response_mask = None
269285
rollout_state.response_model_steps = None
286+
rollout_state.advantage_weight = None
270287
rollout_state.extra_fields["agent_status"] = item.status.value
271288
if item.error is not None:
272289
rollout_state.error_msg = f"{item.error.stage}/{item.error.category}: {item.error.message}"

xtuner/v1/rl/agent_loop/localhost_agent_loop/compose.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,13 @@ def __init__(
3535
async def run(self, item: AgentRolloutItem, record: StageRecord) -> float:
3636
record.status = StageStatus.RUNNING
3737
record.started_at = record.started_at or time.monotonic()
38+
record.judger_name = self.name
3839
try:
3940
weighted_score = 0.0
4041
total_weight = 0.0
4142
for stage in self.stages:
4243
name = getattr(stage, "name", stage.__class__.__name__)
43-
child_record = item.judgers.setdefault(name, StageRecord())
44+
child_record = item.judgers.setdefault(name, StageRecord(judger_name=name))
4445
score = float(await stage.run(item, child_record))
4546
stage_weight = max(float(getattr(stage, "weight", 1.0)), 0.0)
4647
weighted_score += score * stage_weight

xtuner/v1/rl/agent_loop/sandbox_agent_loop/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
AgentInSandboxLoop,
1212
AgentInSandboxLoopConfig,
1313
)
14+
from xtuner.v1.rl.agent_loop.sandbox_agent_loop.compose import SandboxComposeStage
1415
from xtuner.v1.rl.agent_loop.sandbox_agent_loop.hooks import (
1516
DownloadHook,
1617
ExecHook,
@@ -71,6 +72,7 @@
7172
"RunAgentInstallDeps",
7273
"Runner",
7374
"SandboxPool",
75+
"SandboxComposeStage",
7476
"SandboxSpec",
7577
"SandboxStage",
7678
"ShellEntry",

xtuner/v1/rl/agent_loop/sandbox_agent_loop/agent_in_sandbox_loop.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ class AgentInSandboxLoopConfig(AgentLoopConfig):
178178
max_concurrent_samples: int | None = None
179179
mode: Literal["train", "eval"] = "train"
180180
requires_rollout_proxy: bool = True
181+
process_advantage_builder: str | None = None
181182

182183
def build_local(
183184
self, rollout_controller: RolloutController | None = None, judger: Judger | None = None, logger=None
@@ -190,6 +191,7 @@ def build_local(
190191
logger=logger,
191192
max_concurrent_samples=self.max_concurrent_samples,
192193
mode=self.mode,
194+
process_advantage_builder=self.process_advantage_builder,
193195
)
194196

195197

@@ -203,13 +205,17 @@ def __init__(
203205
logger=None,
204206
max_concurrent_samples: int | None = None,
205207
mode: Literal["train", "eval"] = "train",
208+
process_advantage_builder: str | None = None,
206209
):
207210
if hf_checkpoint is None:
208211
raise ValueError("hf_checkpoint must be provided for AgentInSandboxLoop.")
209212
super().__init__(rollout_ctl, sample_params, hf_checkpoint, judger, logger)
210213
self.max_concurrent_samples = max_concurrent_samples
211214
self._sample_semaphore = asyncio.Semaphore(max_concurrent_samples) if max_concurrent_samples else None
212215
self.mode = mode
216+
self.process_advantage_builder = (
217+
_import_from_path(process_advantage_builder) if process_advantage_builder is not None else None
218+
)
213219

214220
async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]:
215221
async def generate_one(state: RolloutState) -> list[RolloutState]:
@@ -313,6 +319,16 @@ async def _build_rollout_states(self, rollout_state: RolloutState, item: AgentRo
313319
data = await trace_store.export_training_trace.remote(str(rollout_state.session_id), prompt_text)
314320
segment_state.input_ids = data["input_ids"]
315321
segment_state.labels = data["labels"]
322+
segment_state.extra_fields["agent_trace_segments"] = data.get("segments", [])
323+
if self.process_advantage_builder is not None:
324+
segment_state.advantage_weight, process_adv_summary = self.process_advantage_builder(
325+
messages,
326+
data["labels"],
327+
data.get("segments"),
328+
)
329+
segment_state.extra_fields["process_adv"] = process_adv_summary
330+
else:
331+
segment_state.advantage_weight = None
316332
# Agentic training consumes input_ids/labels directly. response_ids is
317333
# filled here only so rollout throughput logging can print rollout_tgs.
318334
segment_state.response_ids = [
@@ -341,6 +357,7 @@ def _fill_eval_rollout_state(self, rollout_state: RolloutState, item: AgentRollo
341357
rollout_state.routed_experts = None
342358
rollout_state.response_mask = None
343359
rollout_state.response_model_steps = None
360+
rollout_state.advantage_weight = None
344361
rollout_state.extra_fields["agent_status"] = item.status.value
345362
selected_agent = _selected_agent(item)
346363
if selected_agent is not None:
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
"""Composable sandbox validation stages."""
2+
3+
from __future__ import annotations
4+
5+
import time
6+
from typing import Any
7+
8+
from lagent.utils import create_object
9+
10+
from xtuner.v1.rl.agent_loop.sandbox_agent_loop.sandbox import SandboxPool
11+
from xtuner.v1.rl.agent_loop.sandbox_agent_loop.schemas import (
12+
AgentRolloutItem,
13+
RolloutError,
14+
StageRecord,
15+
StageStatus,
16+
)
17+
18+
19+
class SandboxComposeStage:
20+
"""Compose multiple sandbox validation stages behind ``run(...) -> float``.
21+
22+
Stages with ``weight=0`` still run, but do not contribute to the returned
23+
score. This is used for process-adv annotators that mutate rollout
24+
artifacts without changing outcome reward.
25+
"""
26+
27+
def __init__(
28+
self,
29+
stages: list[Any],
30+
*,
31+
name: str = "validate",
32+
weight: float = 1.0,
33+
):
34+
if not stages:
35+
raise ValueError("SandboxComposeStage.stages is empty")
36+
self.name = name
37+
self.stages = [create_object(stage) for stage in stages]
38+
self.weight = weight
39+
40+
async def run(self, item: AgentRolloutItem, pool: SandboxPool, record: StageRecord) -> float:
41+
record.status = StageStatus.RUNNING
42+
record.started_at = record.started_at or time.monotonic()
43+
record.judger_name = self.name
44+
try:
45+
weighted_score = 0.0
46+
total_weight = 0.0
47+
for stage in self.stages:
48+
name = getattr(stage, "name", stage.__class__.__name__)
49+
child_record = item.judgers.setdefault(name, StageRecord(judger_name=name))
50+
score = float(await stage.run(item, pool, child_record))
51+
stage_weight = max(float(getattr(stage, "weight", 1.0)), 0.0)
52+
weighted_score += score * stage_weight
53+
total_weight += stage_weight
54+
record.score = weighted_score / total_weight if total_weight > 0 else 0.0
55+
record.status = StageStatus.COMPLETED
56+
return record.score
57+
except Exception as exc:
58+
record.status = StageStatus.FAILED
59+
child_error = next(
60+
(child.error for child in item.judgers.values() if child.error is not None),
61+
None,
62+
)
63+
record.error = (
64+
record.error
65+
or child_error
66+
or RolloutError(
67+
stage=self.name,
68+
category="validate_failed",
69+
type=type(exc).__name__,
70+
message=str(exc),
71+
)
72+
)
73+
raise
74+
finally:
75+
record.finished_at = time.monotonic()
76+
77+
78+
__all__ = ["SandboxComposeStage"]

xtuner/v1/rl/rollout/chat_template.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55

66
_RAW_ARGUMENTS_KEY = "__xtuner_raw_arguments__"
7+
_PROCESS_ONLY_MESSAGE_KEYS = ("finish_reason", "metainfo")
78

89

910
def canonicalize_messages_for_chat_template(messages: list[dict]) -> list[dict]:
@@ -19,6 +20,8 @@ def canonicalize_messages_for_chat_template(messages: list[dict]) -> list[dict]:
1920

2021
messages = copy.deepcopy(messages)
2122
for message in messages:
23+
for key in _PROCESS_ONLY_MESSAGE_KEYS:
24+
message.pop(key, None)
2225
tool_calls = message.get("tool_calls")
2326
if not isinstance(tool_calls, list):
2427
continue

xtuner/v1/rl/rollout/trace_store.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def export_training_trace(self, session_id: str, prompt_text: str) -> dict:
323323
324324
Returns:
325325
dict: The trace dictionary containing `input_ids`, `labels`, `logprobs`,
326-
and `routed_experts`.
326+
`routed_experts`, and per-segment token spans.
327327
328328
Raises:
329329
ValueError: If the prompt_text does not completely match the trace keys in the session.
@@ -353,17 +353,34 @@ def export_training_trace(self, session_id: str, prompt_text: str) -> dict:
353353
f"prompt_len={len(prompt_text)} matched_len={len(key)} key_count={len(session_keys)}. "
354354
"See the logged '[TraceStore] prompt mismatch' report for the full diff."
355355
)
356-
trace: dict[str, list[Any]] = {"input_ids": [], "labels": [], "logprobs": [], "routed_experts": []}
356+
trace: dict[str, list[Any]] = {
357+
"input_ids": [],
358+
"labels": [],
359+
"logprobs": [],
360+
"routed_experts": [],
361+
"segments": [],
362+
}
357363
for node in nodes:
358364
node_val = node.value
359365
if not isinstance(node_val, TokenizedSegment):
360366
raise TypeError(f"Unexpected trace node value type: {type(node_val)!r}")
361367
assert node_val.labels is not None
362368
assert node_val.logprobs is not None
369+
start = len(trace["input_ids"])
370+
end = start + len(node_val.token_ids)
371+
trainable = any(label != -100 for label in node_val.labels)
363372
trace["input_ids"].extend(node_val.token_ids)
364373
trace["labels"].extend(node_val.labels)
365374
trace["logprobs"].extend(node_val.logprobs)
366375
trace["routed_experts"].append(node_val.expert_key)
376+
trace["segments"].append(
377+
{
378+
"start": start,
379+
"end": end,
380+
"trainable": trainable,
381+
"kind": "assistant_response" if trainable else "context_delta",
382+
}
383+
)
367384
return trace
368385

369386
def get_objects(self, keys: list[str]) -> list[ray.ObjectRef]:

0 commit comments

Comments
 (0)