Skip to content

Commit 1d14d57

Browse files
authored
[dev] fix: validation metrics (volcengine#1374)
### Checklist Before Starting - [x] Search for similar PR(s). ### What does this PR do? 1. Fix the error where `metric` is not added when `n == 1`. 2. Remove `std@1`. 3. Add an assertion for the case where initial validation runs but `val_metrics` is empty. ### Additional Info. - **Issue Number**: none - **Training**: none - **Inference**: none ### Checklist Before Submitting - [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting). - [x] Add `[BREAKING]` to the PR title if it breaks any API. - [x] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs). - [x] Add CI test(s) if necessary.
1 parent 069a417 commit 1d14d57

File tree

4 files changed

+33
-33
lines changed

4 files changed

+33
-33
lines changed

recipe/dapo/src/dapo_ray_trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def fit(self):
6767
# currently, we only support validation using the reward_function.
6868
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
6969
val_metrics = self._validate()
70+
assert val_metrics, f"{val_metrics=}"
7071
pprint(f"Initial validation metrics: {val_metrics}")
7172
logger.log(data=val_metrics, step=self.global_steps)
7273
if self.config.trainer.get("val_only", False):

recipe/prime/prime_ray_trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ def fit(self):
327327
# currently, we only support validation using the reward_function.
328328
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
329329
val_metrics = self._validate()
330+
assert val_metrics, f"{val_metrics=}"
330331
pprint(f"Initial validation metrics: {val_metrics}")
331332
logger.log(data=val_metrics, step=self.global_steps)
332333
if self.config.trainer.get("val_only", False):

verl/trainer/ppo/metric_utils.py

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -208,45 +208,42 @@ def process_validation_metrics(data_sources: list[str], sample_inputs: list[str]
208208
for var_name, var_vals in var2vals.items():
209209
if isinstance(var_vals[0], str):
210210
continue
211+
211212
metric = {}
212213
n_resps = len(var_vals)
213214
metric[f"mean@{n_resps}"] = np.mean(var_vals)
214-
metric[f"std@{n_resps}"] = np.std(var_vals)
215-
216-
ns = []
217-
n = 2
218-
while n < n_resps:
219-
ns.append(n)
220-
n *= 2
221-
ns.append(n_resps)
222215

223-
# If there are multiple responses, we can compute the best/worst-of-N metrics
224-
# If not, they are the same as the single response metrics
225216
if n_resps > 1:
217+
# n = n_resps
218+
metric[f"std@{n_resps}"] = np.std(var_vals)
219+
220+
metric[f"best@{n_resps}/mean"] = np.max(var_vals)
221+
metric[f"worst@{n_resps}/mean"] = np.min(var_vals)
222+
if var2vals.get("pred", None) is not None:
223+
vote_data = [{"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"])]
224+
metric[f"maj@{n_resps}/mean"] = calc_maj_val(vote_data, vote_key="pred", val_key="val")
225+
# 1 < n < n_resps
226+
ns = []
227+
n = 2
228+
while n < n_resps:
229+
ns.append(n)
230+
n *= 2
231+
226232
for n in ns:
227-
if n == n_resps:
228-
# Non-bootstrapped
229-
metric[f"best@{n}/mean"] = np.max(var_vals)
230-
metric[f"worst@{n}/mean"] = np.min(var_vals)
231-
if var2vals.get("pred", None) is not None:
232-
vote_data = [{"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"])]
233-
metric[f"maj@{n}/mean"] = calc_maj_val(vote_data, vote_key="pred", val_key="val")
234-
else:
235-
# Bootstrapped
236-
[(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric(data=var_vals, subset_size=n, reduce_fns=[np.max, np.min], seed=seed)
237-
metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std
238-
metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std
239-
if var2vals.get("pred", None) is not None:
240-
vote_data = [{"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"])]
241-
[(maj_n_mean, maj_n_std)] = bootstrap_metric(
242-
data=vote_data,
243-
subset_size=n,
244-
reduce_fns=[partial(calc_maj_val, vote_key="pred", val_key="val")],
245-
seed=seed,
246-
)
247-
metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std
248-
249-
data_src2prompt2var2metric[data_source][prompt][var_name] = metric
233+
[(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric(data=var_vals, subset_size=n, reduce_fns=[np.max, np.min], seed=seed)
234+
metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std
235+
metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std
236+
if var2vals.get("pred", None) is not None:
237+
vote_data = [{"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"])]
238+
[(maj_n_mean, maj_n_std)] = bootstrap_metric(
239+
data=vote_data,
240+
subset_size=n,
241+
reduce_fns=[partial(calc_maj_val, vote_key="pred", val_key="val")],
242+
seed=seed,
243+
)
244+
metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std
245+
246+
data_src2prompt2var2metric[data_source][prompt][var_name] = metric
250247

251248
# Aggregate metrics across prompts
252249
data_src2var2metric2prompt_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

verl/trainer/ppo/ray_trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,7 @@ def fit(self):
869869
# currently, we only support validation using the reward_function.
870870
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
871871
val_metrics = self._validate()
872+
assert val_metrics, f"{val_metrics=}"
872873
pprint(f"Initial validation metrics: {val_metrics}")
873874
logger.log(data=val_metrics, step=self.global_steps)
874875
if self.config.trainer.get("val_only", False):

0 commit comments

Comments
 (0)