Skip to content

Commit fba0939

Browse files
Jackie2049 (with co-authors Claude and Boundless)
authored
[trainer] fix: return NaN for empty tensors in compute_data_metrics (#5899)
### What does this PR do? Fixes #5894 When all samples are aborted or `response_mask` is all `False`, `compute_data_metrics` crashes with: ``` RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument. ``` This PR replaces the crash with a warning + NaN return, so training can continue gracefully in edge cases such as: - **System debugging**: early training runs may produce invalid responses; crashing blocks iteration - **Unique sampling strategies**: some strategies may occasionally produce batches where all responses are invalid Note: `compute_data_metrics` is purely a logging/monitoring function — its return values do not feed into loss computation. When all responses are invalid, the actual loss gradient is 0 (via `masked_sum / (num_tokens + 1e-8) = 0`), meaning no parameter updates occur, which is the expected behavior. ### Checklist Before Starting - [x] Search for similar PRs: - https://github.com/volcengine/verl/pulls?q=is%3Apr+compute_data_metrics - Found related PR #5860 which fixes a similar issue in `calculate_debug_metrics` ### Test This is a defensive bug fix for an edge case. ### API and Usage Example No API changes. 
The function signature remains the same: ```python from verl.trainer.ppo.metric_utils import compute_data_metrics # Normal case - works as before metrics = compute_data_metrics(batch) # Returns: {'critic/score/mean': 0.5, 'critic/rewards/max': 1.0, ...} # Edge case - now handled gracefully instead of crashing metrics = compute_data_metrics(batch_with_all_aborted) # Logs warning, returns: {'critic/score/mean': nan, 'critic/rewards/max': nan, ...} # Training continues normally (loss gradient = 0, no parameter updates) ``` ### Design & Code Changes **File changed:** `verl/trainer/ppo/metric_utils.py` **Root Cause:** `torch.max()` / `torch.min()` operations on empty tensors raise RuntimeError when: - All samples are aborted (empty `non_aborted_sequence_score/reward`) - `response_mask` is all False (empty `valid_adv/valid_returns/valid_values`) **Solution:** Add `numel() > 0` checks with NaN fallback (same pattern as PR #5860): ```python if non_aborted_sequence_score.numel() > 0: score_mean = torch.mean(non_aborted_sequence_score).detach().item() score_max = torch.max(non_aborted_sequence_score).detach().item() score_min = torch.min(non_aborted_sequence_score).detach().item() else: logger.warning("All samples are aborted, returning default score metrics") score_mean = score_max = score_min = float("nan") ``` Applied to: - `non_aborted_sequence_score/reward` → score/reward metrics - `valid_adv/valid_returns` → advantage/return metrics - `valid_values` → critic/values metrics (when use_critic=True) - `non_aborted_response_length` → response length metrics ### Checklist Before Submitting - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md) - [x] Apply pre-commit checks (ruff check & format passed locally) - [x] Add / Update documentation (N/A - internal bug fix) - [x] Add unit or end-to-end test(s) (N/A - edge case, no tests needed) --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> 
Co-authored-by: Boundless <ruihang_wu@163.com>
1 parent 43d1c6f commit fba0939

File tree

1 file changed

+70
-30
lines changed

1 file changed

+70
-30
lines changed

verl/trainer/ppo/metric_utils.py

Lines changed: 70 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
Metrics related to the PPO trainer.
1616
"""
1717

18+
import logging
1819
from collections import defaultdict
1920
from functools import partial
2021
from typing import Any, Callable
@@ -26,6 +27,8 @@
2627
from verl import DataProto
2728
from verl.utils.import_utils import deprecated
2829

30+
logger = logging.getLogger(__name__)
31+
2932

3033
@deprecated("verl.utils.metric.reduce_metrics")
3134
def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]:
@@ -128,22 +131,40 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
128131
non_aborted_sequence_score = sequence_score[non_aborted_mask]
129132
non_aborted_sequence_reward = sequence_reward[non_aborted_mask]
130133

131-
score_mean = torch.mean(non_aborted_sequence_score).detach().item()
132-
score_max = torch.max(non_aborted_sequence_score).detach().item()
133-
score_min = torch.min(non_aborted_sequence_score).detach().item()
134+
if non_aborted_sequence_score.numel() > 0:
135+
score_mean = torch.mean(non_aborted_sequence_score).detach().item()
136+
score_max = torch.max(non_aborted_sequence_score).detach().item()
137+
score_min = torch.min(non_aborted_sequence_score).detach().item()
138+
else:
139+
logger.warning("All samples are aborted, returning default score metrics")
140+
score_mean = score_max = score_min = float("nan")
134141

135-
reward_mean = torch.mean(non_aborted_sequence_reward).detach().item()
136-
reward_max = torch.max(non_aborted_sequence_reward).detach().item()
137-
reward_min = torch.min(non_aborted_sequence_reward).detach().item()
142+
if non_aborted_sequence_reward.numel() > 0:
143+
reward_mean = torch.mean(non_aborted_sequence_reward).detach().item()
144+
reward_max = torch.max(non_aborted_sequence_reward).detach().item()
145+
reward_min = torch.min(non_aborted_sequence_reward).detach().item()
146+
else:
147+
logger.warning("All samples are aborted, returning default reward metrics")
148+
reward_mean = reward_max = reward_min = float("nan")
138149

139150
valid_adv = torch.masked_select(advantages, response_mask)
140151
valid_returns = torch.masked_select(returns, response_mask)
141152

142-
if use_critic:
143-
values = batch.batch["values"]
144-
valid_values = torch.masked_select(values, response_mask)
145-
return_diff_var = torch.var(valid_returns - valid_values)
146-
return_var = torch.var(valid_returns)
153+
if valid_adv.numel() > 0:
154+
adv_mean = torch.mean(valid_adv).detach().item()
155+
adv_max = torch.max(valid_adv).detach().item()
156+
adv_min = torch.min(valid_adv).detach().item()
157+
else:
158+
logger.warning("Response mask is all False, returning default advantage metrics")
159+
adv_mean = adv_max = adv_min = float("nan")
160+
161+
if valid_returns.numel() > 0:
162+
returns_mean = torch.mean(valid_returns).detach().item()
163+
returns_max = torch.max(valid_returns).detach().item()
164+
returns_min = torch.min(valid_returns).detach().item()
165+
else:
166+
logger.warning("Response mask is all False, returning default return metrics")
167+
returns_mean = returns_max = returns_min = float("nan")
147168

148169
# Aborted samples and non-aborted response length statistics
149170
# response_length_non_aborted/*: statistics computed on non-aborted samples only
@@ -158,7 +179,37 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
158179
torch.mean(torch.eq(non_aborted_response_length, max_response_length).float()).detach().item()
159180
)
160181
else:
161-
raise ValueError("All samples are aborted, this should not happen.")
182+
logger.warning("All samples are aborted, returning default response length metrics")
183+
non_aborted_response_length_mean = float("nan")
184+
non_aborted_response_length_max = float("nan")
185+
non_aborted_response_length_min = float("nan")
186+
non_aborted_response_length_clip_ratio = float("nan")
187+
188+
if use_critic:
189+
values = batch.batch["values"]
190+
valid_values = torch.masked_select(values, response_mask)
191+
if valid_returns.numel() > 0 and valid_values.numel() > 0:
192+
return_diff_var = torch.var(valid_returns - valid_values)
193+
return_var = torch.var(valid_returns)
194+
critic_value_metrics = {
195+
# values
196+
"critic/values/mean": torch.mean(valid_values).detach().item(),
197+
"critic/values/max": torch.max(valid_values).detach().item(),
198+
"critic/values/min": torch.min(valid_values).detach().item(),
199+
# vf explained var
200+
"critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
201+
}
202+
else:
203+
logger.warning("Response mask is all False, returning default value metrics")
204+
critic_value_metrics = {
205+
"critic/values/mean": float("nan"),
206+
"critic/values/max": float("nan"),
207+
"critic/values/min": float("nan"),
208+
# vf explained var
209+
"critic/vf_explained_var": float("nan"),
210+
}
211+
else:
212+
critic_value_metrics = {}
162213

163214
metrics = {
164215
# score
@@ -170,25 +221,14 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str,
170221
"critic/rewards/max": reward_max,
171222
"critic/rewards/min": reward_min,
172223
# adv
173-
"critic/advantages/mean": torch.mean(valid_adv).detach().item(),
174-
"critic/advantages/max": torch.max(valid_adv).detach().item(),
175-
"critic/advantages/min": torch.min(valid_adv).detach().item(),
224+
"critic/advantages/mean": adv_mean,
225+
"critic/advantages/max": adv_max,
226+
"critic/advantages/min": adv_min,
176227
# returns
177-
"critic/returns/mean": torch.mean(valid_returns).detach().item(),
178-
"critic/returns/max": torch.max(valid_returns).detach().item(),
179-
"critic/returns/min": torch.min(valid_returns).detach().item(),
180-
**(
181-
{
182-
# values
183-
"critic/values/mean": torch.mean(valid_values).detach().item(),
184-
"critic/values/max": torch.max(valid_values).detach().item(),
185-
"critic/values/min": torch.min(valid_values).detach().item(),
186-
# vf explained var
187-
"critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
188-
}
189-
if use_critic
190-
else {}
191-
),
228+
"critic/returns/mean": returns_mean,
229+
"critic/returns/max": returns_max,
230+
"critic/returns/min": returns_min,
231+
**critic_value_metrics,
192232
# response length
193233
"response_length/mean": torch.mean(response_length).detach().item(),
194234
"response_length/max": torch.max(response_length).detach().item(),

0 commit comments

Comments
 (0)