[algo] fix: strip '+' suffix in kl_penalty so k3+/low_var_kl+ work#6058
Merged
tongyx361 merged 1 commit intoverl-project:mainfrom Apr 20, 2026
Merged
Conversation
`kl_penalty()` switches on `kl_penalty.endswith("+")` to enable the
straight-through trick, but forwards the raw string to
`kl_penalty_forward()` without stripping the trailing `+`. Any
`+`-suffixed estimator (`k1+`, `kl+`, `abs+`, `k3+`, `low_var_kl+`)
therefore falls through every branch and hits `raise NotImplementedError`,
crashing the first training step that touches KL (via
`algorithm.kl_penalty` or `actor.kl_loss_type`).
Strip the suffix before dispatching, and add CPU regression tests that
the suffixed value matches the base estimator and that `k3+` produces
the k2 gradient.
Contributor
There was a problem hiding this comment.
Code Review
This pull request fixes a bug in the kl_penalty function by stripping the '+' suffix before dispatching to the forward pass, preventing NotImplementedError for straight-through estimators like 'k3+'. It also introduces regression tests to verify that forward values match their base estimators and that gradients are correctly routed through the 'k2' estimator. I have no feedback to provide.
tongyx361
approved these changes
Apr 20, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Fix a long-standing bug in
core_algos.kl_penaltythat makes every+-suffixed estimator name (k1+,kl+,abs+,k3+,low_var_kl+)crash with
NotImplementedErroron the first training step that touchesKL — either through
algorithm.kl_penalty(KL-in-reward) oractor.kl_loss_type(actor KL loss).The straight-through trick that the
+suffix is supposed to enable(unbiased k3 value with unbiased k2 gradient, see
Schulman, Approximating KL divergence)
was added in #2953 but has never actually worked end-to-end because the
suffix was forwarded to
kl_penalty_forwardwithout being stripped,causing the dispatch to fall through to
raise NotImplementedError.Checklist Before Starting
is:pr kl_penalty k3+— 0 open, 2 closed (docstring typo and a rename, neither addresses this bug)is:pr kl_penalty endswith— 0 resultsis:pr kl_penalty_forward— 0 resultsis:issue "k3+"— no related issue[{modules}] {type}: {description}—[algo] fix: …. Single-line behavior fix inverl/trainer/ppo/core_algos.py, no API surface change.Test
Added two CPU regression tests in
tests/trainer/ppo/test_core_algos_on_cpu.pythat get picked upautomatically by
cpu_unit_tests.yml(tests/**/test_*_on_cpu.py):test_kl_penalty_straight_through_value_matches_base— parametrizedover
(k1+, kl+, abs+, k3+, low_var_kl+), asserts that the suffixedestimator returns the same forward value as its base.
test_kl_penalty_k3_plus_uses_k2_gradient— asserts that thegradient w.r.t.
logprobproduced byk3+is exactly the gradientproduced by
k2, i.e. the straight-through trick is wired correctly.Local run on top of the change:
Without the fix, every
..._straight_through_value_matches_base[*]case raises
NotImplementedErrorinkl_penalty_forward.API and Usage Example
No API change. The fix simply makes the previously-documented
+suffix work as advertised. After this PR the following configurations
are usable (they currently crash on step 1):
Design & Code Changes
Root cause is a single-line oversight in
core_algos.kl_penalty:kl_penalty_forwardonly recognizes the bare names (k1,kl,abs,mse,k2,k3,low_var_kl,full), so it falls through toraise NotImplementedErrorwhenever the input ends with+.Fix: strip the optional suffix before dispatch, leaving the outer
endswith("+")switch untouched so the straight-through path is stillselected correctly:
Files changed:
verl/trainer/ppo/core_algos.py— 1-line fix + 1-line comment.tests/trainer/ppo/test_core_algos_on_cpu.py— 2 regression tests.Checklist Before Submitting
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always— all 12 hooks pass on the touched files (ruff,ruff-format,mypy,autogen-trainer-cfg,check-docs-time-info,check-docstrings,check-license,check-device-api-usage,check-dataproto-usage,validate-structure,check-naming-conventions,compileall).kl_penaltydocstring).tests/trainer/ppo/test_core_algos_on_cpu.pyand are picked up automatically bycpu_unit_tests.yml(tests/**/test_*_on_cpu.py).ci-requestchannel.recipesubmodule.