Commit 7e80ab0
[algo] fix: strip '+' suffix in kl_penalty so k3+/low_var_kl+ work (#6058)
### What does this PR do?
Fix a long-standing bug in `core_algos.kl_penalty` that makes every
`+`-suffixed estimator name (`k1+`, `kl+`, `abs+`, `k3+`, `low_var_kl+`)
crash with `NotImplementedError` on the first training step that touches
KL — either through `algorithm.kl_penalty` (KL-in-reward) or
`actor.kl_loss_type` (actor KL loss).
The straight-through trick that the `+` suffix is supposed to enable
(unbiased k3 value with unbiased k2 gradient, see
[Schulman, *Approximating KL
divergence*](http://joschu.net/blog/kl-approx.html))
was added in #2953 but has never actually worked end-to-end because the
suffix was forwarded to `kl_penalty_forward` without being stripped,
causing the dispatch to fall through to `raise NotImplementedError`.
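
For reference, the trick can be written out under the usual stop-gradient formulation. This is a sketch rather than a transcription of the code: the symbols $\ell$ (for `logprob`), $\ell_{\text{ref}}$ (for `ref_logprob`), $\delta$, $\operatorname{sg}$ and $\tilde{k}$ are notation introduced here, and any clamping is ignored.

$$
\begin{aligned}
\delta &= \ell - \ell_{\text{ref}}, \\
k_2 &= \tfrac{1}{2}\,\delta^{2}, \qquad k_3 = e^{-\delta} + \delta - 1, \\
\tilde{k} &= k_2 - \operatorname{sg}(k_2) + \operatorname{sg}(k_3).
\end{aligned}
$$

Because $\operatorname{sg}(\cdot)$ is the identity in the forward pass and contributes zero gradient, $\tilde{k}$ evaluates to $k_3$ while $\nabla \tilde{k} = \nabla k_2 = \delta\,\nabla \ell$.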
### Checklist Before Starting
- [x] Search for similar PRs. Paste at least one query link here:
  - [`is:pr kl_penalty k3+`](https://github.com/verl-project/verl/pulls?q=is%3Apr+kl_penalty+k3%2B): 0 open, 2 closed (a docstring typo and a rename; neither addresses this bug)
  - [`is:pr kl_penalty endswith`](https://github.com/verl-project/verl/pulls?q=is%3Apr+kl_penalty+endswith): 0 results
  - [`is:pr kl_penalty_forward`](https://github.com/verl-project/verl/pulls?q=is%3Apr+%22kl_penalty_forward%22): 0 results
  - [`is:issue "k3+"`](https://github.com/verl-project/verl/issues?q=is%3Aissue+%22k3%2B%22): no related issue
- [x] Format the PR title as `[{modules}] {type}: {description}`: `[algo] fix: …`. Single-line behavior fix in `verl/trainer/ppo/core_algos.py`, no API surface change.
### Test
Added two CPU regression tests in
`tests/trainer/ppo/test_core_algos_on_cpu.py` that get picked up
automatically by `cpu_unit_tests.yml` (`tests/**/test_*_on_cpu.py`):
1. `test_kl_penalty_straight_through_value_matches_base` — parametrized
over `(k1+, kl+, abs+, k3+, low_var_kl+)`, asserts that the suffixed
estimator returns the same forward value as its base.
2. `test_kl_penalty_k3_plus_uses_k2_gradient` — asserts that the
gradient w.r.t. `logprob` produced by `k3+` is exactly the gradient
produced by `k2`, i.e. the straight-through trick is wired correctly.
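
The second assertion can be illustrated without PyTorch. The sketch below is not the test added in this PR: it emulates `detach()` with a hypothetical forward-mode `Dual` scalar, omits any clamping, and assumes the straight-through form `k2 - k2.detach() + k3.detach()`. All names in it are illustrative.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """Scalar carrying a value and its derivative w.r.t. logprob (forward-mode autodiff)."""

    value: float
    grad: float = 0.0

    def __add__(self, other: "Dual") -> "Dual":
        return Dual(self.value + other.value, self.grad + other.grad)

    def __sub__(self, other: "Dual") -> "Dual":
        return Dual(self.value - other.value, self.grad - other.grad)

    def __mul__(self, other: "Dual") -> "Dual":
        # Product rule.
        return Dual(
            self.value * other.value,
            self.grad * other.value + self.value * other.grad,
        )

    def exp(self) -> "Dual":
        e = math.exp(self.value)
        return Dual(e, e * self.grad)

    def detach(self) -> "Dual":
        # Keep the value, drop the gradient (stand-in for tensor.detach()).
        return Dual(self.value, 0.0)


def k2(logprob: Dual, ref_logprob: Dual) -> Dual:
    diff = logprob - ref_logprob
    return Dual(0.5) * diff * diff


def k3(logprob: Dual, ref_logprob: Dual) -> Dual:
    log_ratio = ref_logprob - logprob
    return log_ratio.exp() - log_ratio - Dual(1.0)


def k3_plus(logprob: Dual, ref_logprob: Dual) -> Dual:
    # Assumed straight-through form: forward value of k3, gradient of k2.
    backward = k2(logprob, ref_logprob)
    forward = k3(logprob, ref_logprob)
    return backward - backward.detach() + forward.detach()


logprob = Dual(-1.2, 1.0)   # seed: d(logprob)/d(logprob) = 1
ref_logprob = Dual(-0.7)    # treated as a constant
st = k3_plus(logprob, ref_logprob)

assert math.isclose(st.value, k3(logprob, ref_logprob).value)
assert math.isclose(st.grad, k2(logprob, ref_logprob).grad)
```

The two assertions mirror the two regression tests: the suffixed estimator keeps the base forward value, and its gradient equals the `k2` gradient rather than the `k3` gradient.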
Local run on top of the change:
```text
$ pytest -xvs tests/trainer/ppo/test_core_algos_on_cpu.py
============================= test session starts ==============================
collected 22 items
tests/trainer/ppo/test_core_algos_on_cpu.py ...................... [100%]
============================== 22 passed in 4.79s ==============================
```
Without the fix, every `..._straight_through_value_matches_base[*]`
case raises `NotImplementedError` in `kl_penalty_forward`.
### API and Usage Example
No API change. The fix makes the previously documented `+` suffix
work as advertised. After this PR the following configurations
are usable (they currently crash on step 1):
```bash
# KL-in-reward with the straight-through trick:
python -m verl.trainer.main_ppo \
    algorithm.use_kl_in_reward=True \
    algorithm.kl_penalty=k3+ \
    algorithm.kl_ctrl.kl_coef=1e-3

# Actor KL loss with the straight-through trick:
python -m verl.trainer.main_ppo \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl+ \
    actor_rollout_ref.actor.kl_loss_coef=1e-3
```
### Design & Code Changes
The root cause is a single-line oversight in `core_algos.kl_penalty`:
```python
# before
forward_score = kl_penalty_forward(logprob, ref_logprob, kl_penalty)
if not kl_penalty.endswith("+") or kl_penalty in ("mse", "k2"):
    return forward_score
```
`kl_penalty_forward` only recognizes the bare names (`k1`, `kl`, `abs`,
`mse`, `k2`, `k3`, `low_var_kl`, `full`), so it falls through to
`raise NotImplementedError` whenever the input ends with `+`.
Fix: strip the optional suffix before dispatch, leaving the outer
`endswith("+")` switch untouched so the straight-through path is still
selected correctly:
```python
# after
base_kl_penalty = kl_penalty[:-1] if kl_penalty.endswith("+") else kl_penalty
forward_score = kl_penalty_forward(logprob, ref_logprob, base_kl_penalty)
if not kl_penalty.endswith("+") or kl_penalty in ("mse", "k2"):
    return forward_score
```
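
The effect of stripping the suffix can be reproduced in isolation. The following is a minimal scalar sketch, not the library code: `toy_kl_penalty_forward` and `forward_value` are hypothetical stand-ins that work on plain floats, skip clamping, and leave out the `full` estimator.

```python
import math


def toy_kl_penalty_forward(logprob: float, ref_logprob: float, name: str) -> float:
    """Toy dispatcher that, like the real one, only knows bare estimator names."""
    if name in ("kl", "k1"):
        return logprob - ref_logprob
    if name == "abs":
        return abs(logprob - ref_logprob)
    if name in ("mse", "k2"):
        return 0.5 * (logprob - ref_logprob) ** 2
    if name in ("low_var_kl", "k3"):
        log_ratio = ref_logprob - logprob
        return math.exp(log_ratio) - log_ratio - 1.0
    raise NotImplementedError(name)


def forward_value(logprob: float, ref_logprob: float, name: str) -> float:
    """Strip the optional '+' suffix before dispatch, as in the fix above."""
    base_name = name[:-1] if name.endswith("+") else name
    return toy_kl_penalty_forward(logprob, ref_logprob, base_name)


# Suffixed and bare names now produce the same forward value.
for base in ("k1", "kl", "abs", "k3", "low_var_kl"):
    assert forward_value(-1.2, -0.7, base + "+") == forward_value(-1.2, -0.7, base)
```

Passing a suffixed name straight to `toy_kl_penalty_forward` still raises `NotImplementedError`, which is the pre-fix failure mode described above.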
Files changed:
- `verl/trainer/ppo/core_algos.py` — 1-line fix + 1-line comment.
- `tests/trainer/ppo/test_core_algos_on_cpu.py` — 2 regression tests.
### Checklist Before Submitting
- [x] Read the [Contribute
Guide](https://github.com/verl-project/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/verl-project/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always` — all 12 hooks pass on the touched files (`ruff`,
`ruff-format`, `mypy`, `autogen-trainer-cfg`, `check-docs-time-info`,
`check-docstrings`, `check-license`, `check-device-api-usage`,
`check-dataproto-usage`, `validate-structure`,
`check-naming-conventions`, `compileall`).
- [ ] Add / Update [the
documentation](https://github.com/verl-project/verl/tree/main/docs) —
N/A (bug fix; the suffix was already documented in the `kl_penalty`
docstring).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/verl-project/verl/tree/main/.github/workflows).
The new tests are added to `tests/trainer/ppo/test_core_algos_on_cpu.py`
and are picked up automatically by `cpu_unit_tests.yml`
(`tests/**/test_*_on_cpu.py`).
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1).
- [x] Not related to the `recipe` submodule.

1 parent 365df24 commit 7e80ab0
2 files changed
Lines changed: 51 additions & 1 deletion