
Commit 7e80ab0

[algo] fix: strip '+' suffix in kl_penalty so k3+/low_var_kl+ work (#6058)
### What does this PR do?

Fix a long-standing bug in `core_algos.kl_penalty` that makes every `+`-suffixed estimator name (`k1+`, `kl+`, `abs+`, `k3+`, `low_var_kl+`) crash with `NotImplementedError` on the first training step that touches KL — either through `algorithm.kl_penalty` (KL-in-reward) or `actor.kl_loss_type` (actor KL loss).

The straight-through trick that the `+` suffix is supposed to enable (unbiased k3 value with unbiased k2 gradient, see [Schulman, *Approximating KL divergence*](http://joschu.net/blog/kl-approx.html)) was added in #2953 but has never actually worked end-to-end: the suffix was forwarded to `kl_penalty_forward` without being stripped, so the dispatch fell through to `raise NotImplementedError`.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here:
  - [`is:pr kl_penalty k3+`](https://github.com/verl-project/verl/pulls?q=is%3Apr+kl_penalty+k3%2B) — 0 open, 2 closed (a docstring typo and a rename, neither addresses this bug)
  - [`is:pr kl_penalty endswith`](https://github.com/verl-project/verl/pulls?q=is%3Apr+kl_penalty+endswith) — 0 results
  - [`is:pr kl_penalty_forward`](https://github.com/verl-project/verl/pulls?q=is%3Apr+%22kl_penalty_forward%22) — 0 results
  - [`is:issue "k3+"`](https://github.com/verl-project/verl/issues?q=is%3Aissue+%22k3%2B%22) — no related issue
- [x] Format the PR title as `[{modules}] {type}: {description}` — `[algo] fix: …`. Single-line behavior fix in `verl/trainer/ppo/core_algos.py`, no API surface change.

### Test

Added two CPU regression tests in `tests/trainer/ppo/test_core_algos_on_cpu.py` that are picked up automatically by `cpu_unit_tests.yml` (`tests/**/test_*_on_cpu.py`):

1. `test_kl_penalty_straight_through_value_matches_base` — parametrized over (`k1+`, `kl+`, `abs+`, `k3+`, `low_var_kl+`), asserts that the suffixed estimator returns the same forward value as its base.
2. `test_kl_penalty_k3_plus_uses_k2_gradient` — asserts that the gradient w.r.t. `logprob` produced by `k3+` is exactly the gradient produced by `k2`, i.e. the straight-through trick is wired correctly.

Local run on top of the change:

```text
$ pytest -xvs tests/trainer/ppo/test_core_algos_on_cpu.py
============================= test session starts ==============================
collected 22 items

tests/trainer/ppo/test_core_algos_on_cpu.py ...................... [100%]

============================== 22 passed in 4.79s ==============================
```

Without the fix, every `..._straight_through_value_matches_base[*]` case raises `NotImplementedError` in `kl_penalty_forward`.

### API and Usage Example

No API change. The fix simply makes the previously documented `+` suffix work as advertised. After this PR the following configurations are usable (they currently crash on step 1):

```bash
# KL-in-reward with the straight-through trick:
python -m verl.trainer.main_ppo \
    algorithm.use_kl_in_reward=True \
    algorithm.kl_penalty=k3+ \
    algorithm.kl_ctrl.kl_coef=1e-3

# Actor KL loss with the straight-through trick:
python -m verl.trainer.main_ppo \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl+ \
    actor_rollout_ref.actor.kl_loss_coef=1e-3
```

### Design & Code Changes

Root cause is a single-line oversight in `core_algos.kl_penalty`:

```python
# before
forward_score = kl_penalty_forward(logprob, ref_logprob, kl_penalty)
if not kl_penalty.endswith("+") or kl_penalty in ("mse", "k2"):
    return forward_score
```

`kl_penalty_forward` only recognizes the bare names (`k1`, `kl`, `abs`, `mse`, `k2`, `k3`, `low_var_kl`, `full`), so it falls through to `raise NotImplementedError` whenever the input ends with `+`.

Fix: strip the optional suffix before dispatch, leaving the outer `endswith("+")` switch untouched so the straight-through path is still selected correctly:

```python
# after
base_kl_penalty = kl_penalty[:-1] if kl_penalty.endswith("+") else kl_penalty
forward_score = kl_penalty_forward(logprob, ref_logprob, base_kl_penalty)
if not kl_penalty.endswith("+") or kl_penalty in ("mse", "k2"):
    return forward_score
```

Files changed:

- `verl/trainer/ppo/core_algos.py` — 1-line fix + 1-line comment.
- `tests/trainer/ppo/test_core_algos_on_cpu.py` — 2 regression tests.

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/verl-project/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/verl-project/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` — all 12 hooks pass on the touched files (`ruff`, `ruff-format`, `mypy`, `autogen-trainer-cfg`, `check-docs-time-info`, `check-docstrings`, `check-license`, `check-device-api-usage`, `check-dataproto-usage`, `validate-structure`, `check-naming-conventions`, `compileall`).
- [ ] Add / Update [the documentation](https://github.com/verl-project/verl/tree/main/docs) — N/A (bug fix; the suffix was already documented in the `kl_penalty` docstring).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/verl-project/verl/tree/main/.github/workflows). The new tests are added to `tests/trainer/ppo/test_core_algos_on_cpu.py` and are picked up automatically by `cpu_unit_tests.yml` (`tests/**/test_*_on_cpu.py`).
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1).
- [x] Not related to the `recipe` submodule.
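For reviewers unfamiliar with the straight-through trick, here is a minimal, self-contained sketch of what the `+` suffix is meant to do, outside verl. The helper names `k2`, `k3`, and `k3_plus` are illustrative only (not verl's actual functions); the `detach()` pattern is the standard straight-through construction: forward value taken from one expression, backward gradient routed through another.

```python
import torch

def k2(logprob, ref_logprob):
    # Schulman's k2 estimator: 0.5 * log_ratio ** 2.
    return 0.5 * (logprob - ref_logprob) ** 2

def k3(logprob, ref_logprob):
    # Schulman's k3 estimator: ratio - log_ratio - 1, non-negative and unbiased.
    log_ratio = ref_logprob - logprob
    return torch.exp(log_ratio) - log_ratio - 1

def k3_plus(logprob, ref_logprob):
    # Straight-through: detach the k3 value and re-attach k2's gradient path,
    # so the forward value equals k3 exactly while backward flows through k2.
    value = k3(logprob, ref_logprob)
    grad_path = k2(logprob, ref_logprob)
    return value.detach() + grad_path - grad_path.detach()

torch.manual_seed(0)
logprob = torch.randn(4, 8, requires_grad=True)
ref_logprob = torch.randn(4, 8)

# Forward value matches plain k3 ...
assert torch.allclose(k3_plus(logprob, ref_logprob), k3(logprob, ref_logprob))

# ... while the gradient w.r.t. logprob matches k2's.
(grad_plus,) = torch.autograd.grad(k3_plus(logprob, ref_logprob).sum(), logprob)
(grad_k2,) = torch.autograd.grad(k2(logprob, ref_logprob).sum(), logprob)
assert torch.allclose(grad_plus, grad_k2)
```

The two assertions mirror the two regression tests in this PR: value equality with the base estimator, and gradient equality with `k2`.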
1 parent 365df24 commit 7e80ab0

2 files changed

Lines changed: 51 additions & 1 deletion

File tree

tests/trainer/ppo/test_core_algos_on_cpu.py

Lines changed: 48 additions & 0 deletions
```diff
@@ -27,6 +27,7 @@
     compute_rloo_outcome_advantage,
     compute_rloo_vectorized_outcome_advantage,
     get_adv_estimator_fn,
+    kl_penalty,
     register_adv_est,
 )
 
@@ -313,5 +314,52 @@ def test_grpo_and_vectorized_equivalence(batch_size: int, seq_len: int, num_grou
     assert torch.allclose(ret1, ret2, rtol=1e-5, atol=1e-6)
 
 
+@pytest.mark.parametrize(
+    "name,base",
+    [
+        ("k1+", "k1"),
+        ("kl+", "kl"),
+        ("abs+", "abs"),
+        ("k3+", "k3"),
+        ("low_var_kl+", "low_var_kl"),
+    ],
+)
+def test_kl_penalty_straight_through_value_matches_base(name, base):
+    """The ``+`` suffix is a straight-through trick that swaps in the k2
+    gradient while keeping the base estimator's value. Therefore the forward
+    value of e.g. ``k3+`` must match the value of plain ``k3``.
+
+    Regression test for the bug where ``kl_penalty(..., "k3+")`` raised
+    ``NotImplementedError`` because the wrapper forwarded the ``+`` suffix to
+    ``kl_penalty_forward`` without stripping it.
+    """
+    torch.manual_seed(0)
+    logprob = torch.randn(4, 8, requires_grad=True)
+    ref_logprob = torch.randn(4, 8)
+
+    plus_value = kl_penalty(logprob, ref_logprob, name)
+    base_value = kl_penalty(logprob, ref_logprob, base)
+    assert torch.allclose(plus_value, base_value)
+
+
+def test_kl_penalty_k3_plus_uses_k2_gradient():
+    """With ``k3+`` the gradient w.r.t. ``logprob`` should equal the gradient
+    obtained from the ``k2`` (``0.5 * log_ratio**2``) estimator, since the
+    straight-through trick routes the backward pass through ``k2``.
+    """
+    torch.manual_seed(0)
+    logprob = torch.randn(4, 8, requires_grad=True)
+    ref_logprob = torch.randn(4, 8)
+
+    out_plus = kl_penalty(logprob, ref_logprob, "k3+").sum()
+    (grad_plus,) = torch.autograd.grad(out_plus, logprob)
+
+    logprob_k2 = logprob.detach().clone().requires_grad_(True)
+    out_k2 = kl_penalty(logprob_k2, ref_logprob, "k2").sum()
+    (grad_k2,) = torch.autograd.grad(out_k2, logprob_k2)
+
+    assert torch.allclose(grad_plus, grad_k2)
+
+
 if __name__ == "__main__":
     unittest.main()
```

verl/trainer/ppo/core_algos.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -2135,7 +2135,9 @@ def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_pe
     Returns:
         kl_estimate
     """
-    forward_score = kl_penalty_forward(logprob, ref_logprob, kl_penalty)
+    # Strip the optional '+' suffix so e.g. "k3+" dispatches to "k3".
+    base_kl_penalty = kl_penalty[:-1] if kl_penalty.endswith("+") else kl_penalty
+    forward_score = kl_penalty_forward(logprob, ref_logprob, base_kl_penalty)
     if not kl_penalty.endswith("+") or kl_penalty in ("mse", "k2"):
         return forward_score
```

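To see why the one-line change is sufficient, the strip-before-dispatch pattern can be sketched in isolation. Everything here is hypothetical scaffolding (`KL_ESTIMATORS`, `kl_penalty_sketch` are not verl's names, and only three estimators are stubbed in with verl's sign convention, `log_ratio = ref - logprob` for k3): the base name drives the table lookup, while the original unstripped name still selects the straight-through branch.

```python
import torch

# Hypothetical estimator table standing in for kl_penalty_forward's dispatch.
KL_ESTIMATORS = {
    "k1": lambda lp, ref: lp - ref,
    "k2": lambda lp, ref: 0.5 * (lp - ref) ** 2,
    "k3": lambda lp, ref: torch.exp(ref - lp) - (ref - lp) - 1,
}

def kl_penalty_sketch(lp, ref, name):
    # The fix: strip the optional '+' before the lookup. Without this line,
    # "k3+" misses the table and the real code raises NotImplementedError.
    base = name[:-1] if name.endswith("+") else name
    if base not in KL_ESTIMATORS:
        raise NotImplementedError(name)
    score = KL_ESTIMATORS[base](lp, ref)
    # The original (unstripped) name still selects the straight-through branch,
    # so the outer endswith("+") switch is left untouched by the fix.
    if name.endswith("+") and base != "k2":
        grad_path = KL_ESTIMATORS["k2"](lp, ref)
        score = score.detach() + grad_path - grad_path.detach()
    return score

lp = torch.randn(3, 5, requires_grad=True)
ref = torch.randn(3, 5)

# With the strip in place, "k3+" no longer raises and its forward value
# matches plain "k3".
same = torch.allclose(kl_penalty_sketch(lp, ref, "k3+"),
                      kl_penalty_sketch(lp, ref, "k3"))
assert same
```

Deleting the `base = ...` line in this sketch reproduces the pre-fix behavior: the `"k3+"` lookup fails and `NotImplementedError` is raised on the first call.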