
[algo] fix: strip '+' suffix in kl_penalty so k3+/low_var_kl+ work #6058

Merged
tongyx361 merged 1 commit into verl-project:main from MaxwellJryao:fix/kl-penalty-plus-suffix
Apr 20, 2026

@MaxwellJryao
Contributor

What does this PR do?

Fix a long-standing bug in core_algos.kl_penalty that makes every
+-suffixed estimator name (k1+, kl+, abs+, k3+, low_var_kl+)
crash with NotImplementedError on the first training step that touches
KL — either through algorithm.kl_penalty (KL-in-reward) or
actor.kl_loss_type (actor KL loss).

The straight-through trick that the + suffix is supposed to enable
(unbiased k3 value with unbiased k2 gradient, see
Schulman, Approximating KL divergence)
was added in #2953 but has never actually worked end-to-end because the
suffix was forwarded to kl_penalty_forward without being stripped,
causing the dispatch to fall through to raise NotImplementedError.
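
For readers unfamiliar with the trick, the construction can be sketched as follows. The symbols are illustrative notation, not names from the code: δ stands for logprob - ref_logprob, and sg(·) is stop-gradient (`.detach()` in PyTorch). The exact expression used in core_algos.py is not quoted in this PR, so treat this as the generic straight-through form rather than a transcription.

```latex
\begin{align*}
\delta &= \log \pi_\theta(a \mid s) - \log \pi_{\mathrm{ref}}(a \mid s) \\
k_2 &= \tfrac{1}{2}\,\delta^{2}, \qquad k_3 = e^{-\delta} - 1 + \delta \\
k_3^{+} &= k_2 - \operatorname{sg}(k_2) + \operatorname{sg}(k_3) \\
\text{forward value: } k_3^{+} &= k_3, \qquad
\nabla_\theta k_3^{+} = \nabla_\theta k_2 = \delta\,\nabla_\theta \log \pi_\theta(a \mid s)
\end{align*}
```

The first two terms cancel in the forward pass, so the reported value is k3, while only the non-detached k2 term contributes to the backward pass.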

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here:
  • Format the PR title as [{modules}] {type}: {description}. This PR uses [algo] fix: …. Single-line behavior fix in verl/trainer/ppo/core_algos.py, no API surface change.

Test

Added two CPU regression tests in
tests/trainer/ppo/test_core_algos_on_cpu.py that get picked up
automatically by cpu_unit_tests.yml (tests/**/test_*_on_cpu.py):

  1. test_kl_penalty_straight_through_value_matches_base — parametrized
    over (k1+, kl+, abs+, k3+, low_var_kl+), asserts that the suffixed
    estimator returns the same forward value as its base.
  2. test_kl_penalty_k3_plus_uses_k2_gradient — asserts that the
    gradient w.r.t. logprob produced by k3+ is exactly the gradient
    produced by k2, i.e. the straight-through trick is wired correctly.
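
The two properties these tests check can be illustrated without PyTorch. The sketch below is not the test code from this PR: `Dual`, `k2`, `k3`, and `k3_plus` are hypothetical names, and a hand-rolled (value, derivative) pair stands in for autograd, with derivatives taken w.r.t. logprob.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """A scalar value paired with its derivative w.r.t. logprob."""

    val: float
    grad: float

    def detach(self) -> "Dual":
        # Same value, treated as a constant: the derivative becomes zero.
        return Dual(self.val, 0.0)

    def __add__(self, other: "Dual") -> "Dual":
        return Dual(self.val + other.val, self.grad + other.grad)

    def __sub__(self, other: "Dual") -> "Dual":
        return Dual(self.val - other.val, self.grad - other.grad)


def k2(logprob: float, ref_logprob: float) -> Dual:
    d = logprob - ref_logprob
    # Value 0.5 * d^2, derivative d.
    return Dual(0.5 * d * d, d)


def k3(logprob: float, ref_logprob: float) -> Dual:
    d = logprob - ref_logprob
    # Value exp(-d) - 1 + d, derivative 1 - exp(-d).
    return Dual(math.exp(-d) - 1.0 + d, 1.0 - math.exp(-d))


def k3_plus(logprob: float, ref_logprob: float) -> Dual:
    # Straight-through combination: k3 supplies the value, k2 the gradient.
    forward = k3(logprob, ref_logprob)
    backward = k2(logprob, ref_logprob)
    return backward - backward.detach() + forward.detach()


if __name__ == "__main__":
    out = k3_plus(-1.2, -0.7)
    print("value:", out.val, "gradient:", out.grad)
```

Here `out.val` equals the k3 value (first test) and `out.grad` equals the k2 derivative rather than the k3 derivative (second test).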

Local run on top of the change:

$ pytest -xvs tests/trainer/ppo/test_core_algos_on_cpu.py
============================= test session starts ==============================
collected 22 items
tests/trainer/ppo/test_core_algos_on_cpu.py ......................   [100%]
============================== 22 passed in 4.79s ==============================

Without the fix, every ..._straight_through_value_matches_base[*]
case raises NotImplementedError in kl_penalty_forward.

API and Usage Example

No API change. The fix simply makes the previously-documented +
suffix work as advertised. After this PR the following configurations
are usable (they currently crash on step 1):

# KL-in-reward with the straight-through trick:
python -m verl.trainer.main_ppo \
    algorithm.use_kl_in_reward=True \
    algorithm.kl_penalty=k3+ \
    algorithm.kl_ctrl.kl_coef=1e-3

# Actor KL loss with the straight-through trick:
python -m verl.trainer.main_ppo \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl+ \
    actor_rollout_ref.actor.kl_loss_coef=1e-3

Design & Code Changes

Root cause is a single-line oversight in core_algos.kl_penalty:

# before
forward_score = kl_penalty_forward(logprob, ref_logprob, kl_penalty)
if not kl_penalty.endswith("+") or kl_penalty in ("mse", "k2"):
    return forward_score

kl_penalty_forward only recognizes the bare names (k1, kl, abs,
mse, k2, k3, low_var_kl, full), so it falls through to
raise NotImplementedError whenever the input ends with +.

Fix: strip the optional suffix before dispatch, leaving the outer
endswith("+") switch untouched so the straight-through path is still
selected correctly:

# after
base_kl_penalty = kl_penalty[:-1] if kl_penalty.endswith("+") else kl_penalty
forward_score = kl_penalty_forward(logprob, ref_logprob, base_kl_penalty)
if not kl_penalty.endswith("+") or kl_penalty in ("mse", "k2"):
    return forward_score
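
To see the failure mode and the fix in isolation, here is a toy scalar reproduction. It is a sketch under stated assumptions, not the verl implementation: the dispatcher below works on plain floats, omits clamping and the `full` estimator, and `forward_value_before` / `forward_value_after` are hypothetical helper names covering only the forward-dispatch step.

```python
import math


def kl_penalty_forward(logprob: float, ref_logprob: float, kl_penalty: str) -> float:
    """Toy stand-in for the dispatcher: only bare estimator names are recognized."""
    d = logprob - ref_logprob
    if kl_penalty in ("kl", "k1"):
        return d
    if kl_penalty == "abs":
        return abs(d)
    if kl_penalty in ("mse", "k2"):
        return 0.5 * d * d
    if kl_penalty in ("low_var_kl", "k3"):
        return math.exp(-d) - 1.0 + d
    raise NotImplementedError(kl_penalty)


def forward_value_before(logprob: float, ref_logprob: float, kl_penalty: str) -> float:
    # Pre-fix behavior: the raw name, suffix included, reaches the dispatcher.
    return kl_penalty_forward(logprob, ref_logprob, kl_penalty)


def forward_value_after(logprob: float, ref_logprob: float, kl_penalty: str) -> float:
    # Post-fix behavior: strip one optional trailing '+' before dispatching.
    base = kl_penalty[:-1] if kl_penalty.endswith("+") else kl_penalty
    return kl_penalty_forward(logprob, ref_logprob, base)


if __name__ == "__main__":
    for name in ("k3", "k3+", "low_var_kl+"):
        try:
            before = forward_value_before(-1.2, -0.7, name)
        except NotImplementedError:
            before = "NotImplementedError"
        print(name, "| before:", before, "| after:", forward_value_after(-1.2, -0.7, name))
```

Slicing with `[:-1]` removes exactly one character, so the outer `endswith("+")` check on the original string still decides whether the straight-through path runs.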

Files changed:

  • verl/trainer/ppo/core_algos.py — 1-line fix + 1-line comment.
  • tests/trainer/ppo/test_core_algos_on_cpu.py — 2 regression tests.

Checklist Before Submitting

  • Read the Contribute Guide.
  • Apply pre-commit checks: pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always — all 12 hooks pass on the touched files (ruff, ruff-format, mypy, autogen-trainer-cfg, check-docs-time-info, check-docstrings, check-license, check-device-api-usage, check-dataproto-usage, validate-structure, check-naming-conventions, compileall).
  • Add / Update the documentation — N/A (bug fix; the suffix was already documented in the kl_penalty docstring).
  • Add unit or end-to-end test(s) to the CI workflow. The new tests are added to tests/trainer/ppo/test_core_algos_on_cpu.py and are picked up automatically by cpu_unit_tests.yml (tests/**/test_*_on_cpu.py).
  • Once your PR is ready for CI, send a message in the ci-request channel.
  • Not related to the recipe submodule.

`kl_penalty()` switches on `kl_penalty.endswith("+")` to enable the
straight-through trick, but forwards the raw string to
`kl_penalty_forward()` without stripping the trailing `+`. Any
`+`-suffixed estimator (`k1+`, `kl+`, `abs+`, `k3+`, `low_var_kl+`)
therefore falls through every branch and hits `raise NotImplementedError`,
crashing the first training step that touches KL (via
`algorithm.kl_penalty` or `actor.kl_loss_type`).

Strip the suffix before dispatching, and add CPU regression tests that
the suffixed value matches the base estimator and that `k3+` produces
the k2 gradient.
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request fixes a bug in the kl_penalty function by stripping the '+' suffix before dispatching to the forward pass, preventing NotImplementedError for straight-through estimators like 'k3+'. It also introduces regression tests to verify that forward values match their base estimators and that gradients are correctly routed through the 'k2' estimator. I have no feedback to provide.

@tongyx361 tongyx361 self-assigned this Apr 20, 2026
@tongyx361 tongyx361 merged commit 7e80ab0 into verl-project:main Apr 20, 2026
63 of 73 checks passed