[data] feat: TransferQueue - Unify the return of reward #4902
Changes from 2 commits
```diff
@@ -538,9 +538,9 @@ def _compute_or_extract_reward(
         self,
         batch: DataProto,
         reward_fn=None,
-        return_dict: bool = False,
+        reward_for_val: bool = False,
         sum_reward: bool = False,
-    ) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor | dict[str, Any]:
+    ) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor:
         """
         Compute or extract reward from batch.

```
```diff
@@ -551,49 +551,40 @@ def _compute_or_extract_reward(
         Args:
             batch: DataProto containing the batch data
             reward_fn: Reward function to use if rm_scores doesn't exist (for training/validation)
-            return_dict: Whether to return dict format with reward_extra_info (for validation)
+            reward_for_val: Calculate reward for validation (for validation)
             sum_reward: Whether to sum reward tensor along last dimension (for REMAX baseline)

         Returns:
-            If return_dict=True: dict with "reward_tensor" and "reward_extra_info"
-            If return_dict=False and sum_reward=True: summed reward_tensor (1D tensor)
-            If return_dict=False and sum_reward=False: reward_tensor (2D tensor)
+            If reward_for_val=True: reward_tensor (2D tensor)
```
**Collaborator:** Add descriptions for

**Author:** fixed
```diff
+            If reward_for_val=False and sum_reward=True: summed reward_tensor (1D tensor)
+            If reward_for_val=False and sum_reward=False: reward_tensor (2D tensor)
         """
```

walterchenchn marked this conversation as resolved. (Outdated)
```diff
         # When rm_scores already exists, extract it directly (format conversion only)
         if "rm_scores" in batch.batch.keys():
             reward_tensor = batch.batch["rm_scores"]
             if sum_reward:
                 reward_tensor = reward_tensor.sum(dim=-1)

-            if return_dict:
-                # Extract reward_extra_info if available
-                reward_extra_keys = batch.meta_info.get("reward_extra_keys", [])
-                reward_extra_info = (
-                    {key: batch.non_tensor_batch[key] for key in reward_extra_keys} if reward_extra_keys else {}
-                )
-                return {"reward_tensor": reward_tensor, "reward_extra_info": reward_extra_info}
-            else:
-                # If sum_reward=True, only return tensor (for REMAX baseline)
-                if sum_reward:
-                    return reward_tensor
-                # Otherwise, return tuple with reward_extra_info (for training loop)
-                reward_extra_keys = batch.meta_info.get("reward_extra_keys", [])
-                reward_extra_infos_dict = (
-                    {key: batch.non_tensor_batch[key] for key in reward_extra_keys} if reward_extra_keys else {}
-                )
-                return reward_tensor, reward_extra_infos_dict
+            if not reward_for_val and sum_reward:
+                return reward_tensor
+
+            reward_extra_keys = batch.meta_info.get("reward_extra_keys", [])
+            reward_extra_infos_dict = (
+                {key: batch.non_tensor_batch[key] for key in reward_extra_keys} if reward_extra_keys else {}
+            )
+            return reward_tensor, reward_extra_infos_dict

         # Otherwise, compute reward using reward_fn
         if reward_fn is None:
             raise ValueError("reward_fn must be provided when rm_scores is not available.")

-        if return_dict:
+        if reward_for_val:
             result = reward_fn(batch, return_dict=True)
             reward_tensor = result["reward_tensor"]
             if sum_reward:
                 reward_tensor = reward_tensor.sum(dim=-1)
             reward_extra_info = result.get("reward_extra_info", {})
-            return {"reward_tensor": reward_tensor, "reward_extra_info": reward_extra_info}
+            return reward_tensor, reward_extra_info
```
**Collaborator:** unify with

**Author:** fixed
```diff
         else:
             reward_tensor, reward_extra_infos_dict = compute_reward(batch, reward_fn)
             if sum_reward:
```
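Taken together, this hunk collapses the old three-way return (dict, tuple, or bare tensor) into a single tuple contract with one fast path. Below is a runnable toy model of the post-PR return shapes for reviewers who want to poke at it; the stub only mimics the dispatch and is not the verl implementation.

```python
from typing import Any

import torch


def reward_contract_sketch(
    rm_scores: torch.Tensor,
    extra_info: dict[str, Any],
    reward_for_val: bool = False,
    sum_reward: bool = False,
) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor:
    """Toy model of _compute_or_extract_reward's return contract after this PR."""
    reward = rm_scores
    if sum_reward:
        reward = reward.sum(dim=-1)
    # The one remaining non-tuple return: REMAX-style summed baseline.
    if not reward_for_val and sum_reward:
        return reward
    # Every other path now yields the same (tensor, dict) tuple.
    return reward, extra_info


scores = torch.ones(2, 4)
# Training fast path with sum_reward=True: bare 1D tensor.
assert reward_contract_sketch(scores, {}, sum_reward=True).shape == (2,)
# Validation and plain training paths: uniform tuple unpacking.
t, info = reward_contract_sketch(scores, {}, reward_for_val=True)
assert t.shape == (2, 4) and info == {}
```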
```diff
@@ -695,13 +686,11 @@ def _validate(self, merged: bool = False):
             sample_uids.extend(test_batch.non_tensor_batch["uid"])

             # evaluate using reward_function
-            result = self._compute_or_extract_reward(test_batch, reward_fn=self.val_reward_fn, return_dict=True)
-            reward_tensor = result["reward_tensor"]
+            reward_tensor, reward_extra_info = self._compute_or_extract_reward(test_batch, reward_fn=self.val_reward_fn, reward_for_val=True)
             scores = reward_tensor.sum(-1).cpu().tolist()
             sample_scores.extend(scores)

             reward_extra_infos_dict["reward"].extend(scores)
-            reward_extra_info = result.get("reward_extra_info", {})
             for key, values in reward_extra_info.items():
                 if key not in reward_extra_infos_dict:
                     reward_extra_infos_dict[key] = []
```
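The `_validate` hunk feeds `reward_extra_info` into a per-key accumulator across validation batches. A small self-contained stand-in for that merge pattern, with invented sample data; only the loop structure mirrors the diff:

```python
# Stand-in for the accumulation in _validate: per-batch reward_extra_info
# dicts are merged into one dict of lists spanning all validation batches.
# The sample scores and keys below are made up for illustration.
reward_extra_infos_dict: dict[str, list] = {"reward": []}

batch_scores = [[0.9, 0.1], [0.4, 0.8]]
batch_extra_infos = [
    {"acc": [1.0, 0.0], "format": [1.0, 1.0]},
    {"acc": [0.5, 1.0]},  # keys may differ from batch to batch
]

for scores, reward_extra_info in zip(batch_scores, batch_extra_infos):
    reward_extra_infos_dict["reward"].extend(scores)
    for key, values in reward_extra_info.items():
        if key not in reward_extra_infos_dict:
            reward_extra_infos_dict[key] = []
        reward_extra_infos_dict[key].extend(values)

print(reward_extra_infos_dict)
# {'reward': [0.9, 0.1, 0.4, 0.8], 'acc': [1.0, 0.0, 0.5, 1.0], 'format': [1.0, 1.0]}
```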
```diff
@@ -1525,7 +1514,7 @@ def fit(self):
                 )
             else:
                 reward_tensor, reward_extra_infos_dict = self._compute_or_extract_reward(
-                    batch, reward_fn=self.reward_fn, return_dict=False
+                    batch, reward_fn=self.reward_fn, reward_for_val=False
                 )

                 # Operating Mode Selection:
```
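After this hunk, the training call site in `fit` and the validation call site in `_validate` unpack the same tuple and differ only in the flag and the reward function. A runnable sketch of that symmetry; the stub below merely stands in for the method, since the full class is not part of this diff:

```python
import torch


def _compute_or_extract_reward_stub(batch, reward_fn=None, reward_for_val=False, sum_reward=False):
    """Illustrative stub: always the post-PR (tensor, dict) tuple on the non-summed paths."""
    reward_tensor = torch.full((2, 3), 0.5)
    return reward_tensor, {"path": "val" if reward_for_val else "train"}


# Validation call site (cf. the _validate hunk above):
reward_tensor, reward_extra_info = _compute_or_extract_reward_stub(None, reward_for_val=True)
# Training call site (cf. the fit hunk): identical unpacking, different flag.
reward_tensor, reward_extra_infos_dict = _compute_or_extract_reward_stub(None, reward_for_val=False)

assert reward_extra_info["path"] == "val" and reward_extra_infos_dict["path"] == "train"
```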