verl/trainer/ppo/ray_trainer.py (47 changes: 18 additions & 29 deletions)
@@ -538,9 +538,9 @@ def _compute_or_extract_reward(
         self,
         batch: DataProto,
         reward_fn=None,
-        return_dict: bool = False,
+        reward_for_val: bool = False,
         sum_reward: bool = False,
-    ) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor | dict[str, Any]:
+    ) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor:
         """
         Compute or extract reward from batch.
@@ -551,49 +551,40 @@ def _compute_or_extract_reward(
         Args:
             batch: DataProto containing the batch data
             reward_fn: Reward function to use if rm_scores doesn't exist (for training/validation)
-            return_dict: Whether to return dict format with reward_extra_info (for validation)
+            reward_for_val: Whether the reward is being computed for validation
             sum_reward: Whether to sum reward tensor along last dimension (for REMAX baseline)

         Returns:
-            If return_dict=True: dict with "reward_tensor" and "reward_extra_info"
-            If return_dict=False and sum_reward=True: summed reward_tensor (1D tensor)
-            If return_dict=False and sum_reward=False: reward_tensor (2D tensor)
+            If reward_for_val=True: tuple of (reward_tensor (2D tensor), reward_extra_info dict)
[Comment thread, resolved]
Collaborator: Add descriptions for reward_extra_infos
Contributor (author): fixed
+            If reward_for_val=False and sum_reward=True: summed reward_tensor (1D tensor)
+            If reward_for_val=False and sum_reward=False: tuple of (reward_tensor (2D tensor), reward_extra_infos_dict)
         """
         # When rm_scores already exists, extract it directly (format conversion only)
         if "rm_scores" in batch.batch.keys():
             reward_tensor = batch.batch["rm_scores"]
             if sum_reward:
                 reward_tensor = reward_tensor.sum(dim=-1)

-            if return_dict:
-                # Extract reward_extra_info if available
-                reward_extra_keys = batch.meta_info.get("reward_extra_keys", [])
-                reward_extra_info = (
-                    {key: batch.non_tensor_batch[key] for key in reward_extra_keys} if reward_extra_keys else {}
-                )
-                return {"reward_tensor": reward_tensor, "reward_extra_info": reward_extra_info}
-            else:
-                # If sum_reward=True, only return tensor (for REMAX baseline)
-                if sum_reward:
-                    return reward_tensor
-                # Otherwise, return tuple with reward_extra_info (for training loop)
-                reward_extra_keys = batch.meta_info.get("reward_extra_keys", [])
-                reward_extra_infos_dict = (
-                    {key: batch.non_tensor_batch[key] for key in reward_extra_keys} if reward_extra_keys else {}
-                )
-                return reward_tensor, reward_extra_infos_dict
+            # Return only the summed tensor for the REMAX baseline; extra info is not needed there
+            if not reward_for_val and sum_reward:
+                return reward_tensor
+
+            reward_extra_keys = batch.meta_info.get("reward_extra_keys", [])
+            reward_extra_infos_dict = (
+                {key: batch.non_tensor_batch[key] for key in reward_extra_keys} if reward_extra_keys else {}
+            )
+            return reward_tensor, reward_extra_infos_dict

         # Otherwise, compute reward using reward_fn
         if reward_fn is None:
             raise ValueError("reward_fn must be provided when rm_scores is not available.")

-        if return_dict:
+        if reward_for_val:
             result = reward_fn(batch, return_dict=True)
             reward_tensor = result["reward_tensor"]
             if sum_reward:
                 reward_tensor = reward_tensor.sum(dim=-1)
             reward_extra_info = result.get("reward_extra_info", {})
-            return {"reward_tensor": reward_tensor, "reward_extra_info": reward_extra_info}
+            return reward_tensor, reward_extra_info
[Comment thread, resolved]
Collaborator: unify with reward_extra_info_dict?
Contributor (author): fixed

         else:
             reward_tensor, reward_extra_infos_dict = compute_reward(batch, reward_fn)
             if sum_reward:
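For reference, a minimal standalone sketch (not part of this PR) of the return contract the refactor settles on. The function toy_compute_or_extract_reward and the dummy tensors are hypothetical stand-ins; only the branch logic and the returned shapes follow the diff above.

import torch

def toy_compute_or_extract_reward(rm_scores, reward_for_val=False, sum_reward=False):
    # Stands in for the extract path, where batch.batch["rm_scores"] already exists.
    reward_tensor = rm_scores
    if sum_reward:
        reward_tensor = reward_tensor.sum(dim=-1)
    # REMAX baseline path: a bare 1D tensor, no extra info.
    if not reward_for_val and sum_reward:
        return reward_tensor
    # Training and validation paths: always a (tensor, extra_info) tuple.
    reward_extra_infos_dict = {}  # stands in for keys pulled from non_tensor_batch
    return reward_tensor, reward_extra_infos_dict

scores = torch.randn(4, 16)  # (batch_size, response_len) token-level rewards
assert toy_compute_or_extract_reward(scores, sum_reward=True).shape == (4,)
val_tensor, val_extras = toy_compute_or_extract_reward(scores, reward_for_val=True)
assert val_tensor.shape == (4, 16) and isinstance(val_extras, dict)
train_tensor, train_extras = toy_compute_or_extract_reward(scores)
assert train_tensor.shape == (4, 16)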
@@ -695,13 +686,11 @@ def _validate(self, merged: bool = False):
             sample_uids.extend(test_batch.non_tensor_batch["uid"])

             # evaluate using reward_function
-            result = self._compute_or_extract_reward(test_batch, reward_fn=self.val_reward_fn, return_dict=True)
-            reward_tensor = result["reward_tensor"]
+            reward_tensor, reward_extra_info = self._compute_or_extract_reward(
+                test_batch, reward_fn=self.val_reward_fn, reward_for_val=True
+            )
             scores = reward_tensor.sum(-1).cpu().tolist()
             sample_scores.extend(scores)

             reward_extra_infos_dict["reward"].extend(scores)
-            reward_extra_info = result.get("reward_extra_info", {})
             for key, values in reward_extra_info.items():
                 if key not in reward_extra_infos_dict:
                     reward_extra_infos_dict[key] = []
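As an aside, the accumulation pattern above (seed reward_extra_infos_dict with an empty list on first sight of a key, then extend) can also be expressed with collections.defaultdict. A small self-contained sketch with made-up batch values:

from collections import defaultdict

reward_extra_infos_dict: dict[str, list] = defaultdict(list)
# Two fake validation batches, each carrying per-sample extra reward info.
for reward_extra_info in [{"acc": [1.0, 0.0]}, {"acc": [1.0], "fmt": [0.5]}]:
    for key, values in reward_extra_info.items():
        reward_extra_infos_dict[key].extend(values)

print(dict(reward_extra_infos_dict))  # {'acc': [1.0, 0.0, 1.0], 'fmt': [0.5]}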
@@ -1525,7 +1514,7 @@ def fit(self):
                 )
             else:
                 reward_tensor, reward_extra_infos_dict = self._compute_or_extract_reward(
-                    batch, reward_fn=self.reward_fn, return_dict=False
+                    batch, reward_fn=self.reward_fn, reward_for_val=False
                 )

             # Operating Mode Selection:
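For context on why the REMAX path wants a summed 1D reward: REMAX subtracts the reward of a greedy rollout from the reward of the sampled rollout, one scalar per sequence. A toy sketch with made-up numbers (general REMAX background, not code from this PR):

import torch

sample_reward = torch.tensor([1.0, 0.0, 1.0, 1.0])    # sampled rollouts, one scalar per prompt
baseline_reward = torch.tensor([0.5, 0.5, 1.0, 0.0])  # greedy-rollout baseline
advantage = sample_reward - baseline_reward            # per-sequence REMAX advantage
print(advantage)  # tensor([ 0.5000, -0.5000,  0.0000,  1.0000])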