
Commit aeccf99

[data] feat: TransferQueue - Unify the return of reward (verl-project#4902)
### What does this PR do?

Unify the return values of the reward function to make the logic clearer.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI).
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`.
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`.
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`.
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [x] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.

---

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
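Since the usage-example section above is left as a placeholder, a minimal sketch of the call-site change follows, inferred from the diff below. The `trainer` handle and batch objects are stand-ins for a `RayPPOTrainer` instance and its `DataProto` batches; only `_compute_or_extract_reward` and its keyword arguments come from the patch itself.

```python
# Old API: validation callers received a dict and unpacked it by hand.
result = trainer._compute_or_extract_reward(
    test_batch, reward_fn=trainer.val_reward_fn, return_dict=True
)
reward_tensor = result["reward_tensor"]
reward_extra_info = result.get("reward_extra_info", {})

# New API: validation and training both get the same (tensor, dict) tuple.
reward_tensor, reward_extra_info = trainer._compute_or_extract_reward(
    test_batch, reward_fn=trainer.val_reward_fn, reward_for_val=True
)
reward_tensor, reward_extra_infos_dict = trainer._compute_or_extract_reward(
    batch, reward_fn=trainer.reward_fn, reward_for_val=False
)

# REMAX baseline path is unchanged: sum_reward=True with reward_for_val=False
# returns only the summed 1D reward tensor.
baseline = trainer._compute_or_extract_reward(
    batch, reward_fn=trainer.reward_fn, reward_for_val=False, sum_reward=True
)
```

The dict-shaped return for validation is gone; `reward_for_val=True` now yields the same tuple shape as the training path, which is what lets `_validate` unpack the call directly.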
1 parent a3d6677 commit aeccf99

1 file changed

Lines changed: 20 additions & 30 deletions

File tree

verl/trainer/ppo/ray_trainer.py

```diff
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -538,9 +538,9 @@ def _compute_or_extract_reward(
         self,
         batch: DataProto,
         reward_fn=None,
-        return_dict: bool = False,
+        reward_for_val: bool = False,
         sum_reward: bool = False,
-    ) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor | dict[str, Any]:
+    ) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor:
         """
         Compute or extract reward from batch.
 
@@ -551,49 +551,39 @@ def _compute_or_extract_reward(
         Args:
             batch: DataProto containing the batch data
             reward_fn: Reward function to use if rm_scores doesn't exist (for training/validation)
-            return_dict: Whether to return dict format with reward_extra_info (for validation)
+            reward_for_val: Whether this is for validation
             sum_reward: Whether to sum reward tensor along last dimension (for REMAX baseline)
 
         Returns:
-            If return_dict=True: dict with "reward_tensor" and "reward_extra_info"
-            If return_dict=False and sum_reward=True: summed reward_tensor (1D tensor)
-            If return_dict=False and sum_reward=False: reward_tensor (2D tensor)
+            If reward_for_val=False and sum_reward=True: summed reward_tensor (1D tensor)
+            Otherwise: tuple of (reward_tensor, reward_extra_infos_dict)
         """
         # When rm_scores already exists, extract it directly (format conversion only)
         if "rm_scores" in batch.batch.keys():
             reward_tensor = batch.batch["rm_scores"]
             if sum_reward:
                 reward_tensor = reward_tensor.sum(dim=-1)
 
-            if return_dict:
-                # Extract reward_extra_info if available
-                reward_extra_keys = batch.meta_info.get("reward_extra_keys", [])
-                reward_extra_info = (
-                    {key: batch.non_tensor_batch[key] for key in reward_extra_keys} if reward_extra_keys else {}
-                )
-                return {"reward_tensor": reward_tensor, "reward_extra_info": reward_extra_info}
-            else:
-                # If sum_reward=True, only return tensor (for REMAX baseline)
-                if sum_reward:
-                    return reward_tensor
-                # Otherwise, return tuple with reward_extra_info (for training loop)
-                reward_extra_keys = batch.meta_info.get("reward_extra_keys", [])
-                reward_extra_infos_dict = (
-                    {key: batch.non_tensor_batch[key] for key in reward_extra_keys} if reward_extra_keys else {}
-                )
-                return reward_tensor, reward_extra_infos_dict
+            if not reward_for_val and sum_reward:
+                return reward_tensor
+
+            reward_extra_keys = batch.meta_info.get("reward_extra_keys", [])
+            reward_extra_infos_dict = (
+                {key: batch.non_tensor_batch[key] for key in reward_extra_keys} if reward_extra_keys else {}
+            )
+            return reward_tensor, reward_extra_infos_dict
 
         # Otherwise, compute reward using reward_fn
         if reward_fn is None:
             raise ValueError("reward_fn must be provided when rm_scores is not available.")
 
-        if return_dict:
+        if reward_for_val:
             result = reward_fn(batch, return_dict=True)
             reward_tensor = result["reward_tensor"]
             if sum_reward:
                 reward_tensor = reward_tensor.sum(dim=-1)
-            reward_extra_info = result.get("reward_extra_info", {})
-            return {"reward_tensor": reward_tensor, "reward_extra_info": reward_extra_info}
+            reward_extra_infos_dict = result.get("reward_extra_info", {})
+            return reward_tensor, reward_extra_infos_dict
         else:
             reward_tensor, reward_extra_infos_dict = compute_reward(batch, reward_fn)
             if sum_reward:
@@ -695,13 +685,13 @@ def _validate(self, merged: bool = False):
             sample_uids.extend(test_batch.non_tensor_batch["uid"])
 
             # evaluate using reward_function
-            result = self._compute_or_extract_reward(test_batch, reward_fn=self.val_reward_fn, return_dict=True)
-            reward_tensor = result["reward_tensor"]
+            reward_tensor, reward_extra_info = self._compute_or_extract_reward(
+                test_batch, reward_fn=self.val_reward_fn, reward_for_val=True
+            )
             scores = reward_tensor.sum(-1).cpu().tolist()
             sample_scores.extend(scores)
 
             reward_extra_infos_dict["reward"].extend(scores)
-            reward_extra_info = result.get("reward_extra_info", {})
             for key, values in reward_extra_info.items():
                 if key not in reward_extra_infos_dict:
                     reward_extra_infos_dict[key] = []
@@ -1525,7 +1515,7 @@ def fit(self):
                     )
                 else:
                     reward_tensor, reward_extra_infos_dict = self._compute_or_extract_reward(
-                        batch, reward_fn=self.reward_fn, return_dict=False
+                        batch, reward_fn=self.reward_fn, reward_for_val=False
                    )
 
                 # Operating Mode Selection:
```
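As a complement to the patch, here is a self-contained sketch of the two return shapes the new signature allows; the tensor values and the `acc` key are invented for illustration, and the extraction mirrors the `reward_extra_keys` logic in the diff above.

```python
import torch

# Stand-in for batch.batch["rm_scores"]: one reward per response token,
# shape (batch_size, response_length).
rm_scores = torch.zeros(4, 16)
rm_scores[:, -1] = torch.tensor([1.0, 0.0, 1.0, 0.5])  # terminal rewards

# Default path: 2D reward tensor plus the extra-info dict built from
# batch.meta_info["reward_extra_keys"] and batch.non_tensor_batch.
meta_info = {"reward_extra_keys": ["acc"]}
non_tensor_batch = {"acc": [1, 0, 1, 1]}
reward_extra_keys = meta_info.get("reward_extra_keys", [])
reward_extra_infos_dict = (
    {key: non_tensor_batch[key] for key in reward_extra_keys} if reward_extra_keys else {}
)
assert rm_scores.dim() == 2 and reward_extra_infos_dict == {"acc": [1, 0, 1, 1]}

# REMAX path (reward_for_val=False, sum_reward=True): summed 1D tensor only.
baseline = rm_scores.sum(dim=-1)
assert baseline.tolist() == [1.0, 0.0, 1.0, 0.5]
```

Collapsing the `return_dict` branch means every consumer except the REMAX baseline handles exactly one return shape, which is the simplification the PR title describes.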
