diff --git a/tests/trainer/ppo/test_core_algos_on_cpu.py b/tests/trainer/ppo/test_core_algos_on_cpu.py
index 288f28e6398..e25169cd65f 100644
--- a/tests/trainer/ppo/test_core_algos_on_cpu.py
+++ b/tests/trainer/ppo/test_core_algos_on_cpu.py
@@ -27,6 +27,7 @@
     compute_rloo_outcome_advantage,
     compute_rloo_vectorized_outcome_advantage,
     get_adv_estimator_fn,
+    kl_penalty,
    register_adv_est,
 )
 
@@ -313,5 +314,52 @@ def test_grpo_and_vectorized_equivalence(batch_size: int, seq_len: int, num_grou
     assert torch.allclose(ret1, ret2, rtol=1e-5, atol=1e-6)
 
 
+@pytest.mark.parametrize(
+    "name,base",
+    [
+        ("k1+", "k1"),
+        ("kl+", "kl"),
+        ("abs+", "abs"),
+        ("k3+", "k3"),
+        ("low_var_kl+", "low_var_kl"),
+    ],
+)
+def test_kl_penalty_straight_through_value_matches_base(name, base):
+    """The ``+`` suffix is a straight-through trick that swaps in the k2
+    gradient while keeping the base estimator's value. Therefore the forward
+    value of e.g. ``k3+`` must match the value of plain ``k3``.
+
+    Regression test for the bug where ``kl_penalty(..., "k3+")`` raised
+    ``NotImplementedError`` because the wrapper forwarded the ``+`` suffix to
+    ``kl_penalty_forward`` without stripping it.
+    """
+    torch.manual_seed(0)
+    logprob = torch.randn(4, 8, requires_grad=True)
+    ref_logprob = torch.randn(4, 8)
+
+    plus_value = kl_penalty(logprob, ref_logprob, name)
+    base_value = kl_penalty(logprob, ref_logprob, base)
+    assert torch.allclose(plus_value, base_value)
+
+
+def test_kl_penalty_k3_plus_uses_k2_gradient():
+    """With ``k3+`` the gradient w.r.t. ``logprob`` should equal the gradient
+    obtained from the ``k2`` (``0.5 * log_ratio**2``) estimator, since the
+    straight-through trick routes the backward pass through ``k2``.
+    """
+    torch.manual_seed(0)
+    logprob = torch.randn(4, 8, requires_grad=True)
+    ref_logprob = torch.randn(4, 8)
+
+    out_plus = kl_penalty(logprob, ref_logprob, "k3+").sum()
+    (grad_plus,) = torch.autograd.grad(out_plus, logprob)
+
+    logprob_k2 = logprob.detach().clone().requires_grad_(True)
+    out_k2 = kl_penalty(logprob_k2, ref_logprob, "k2").sum()
+    (grad_k2,) = torch.autograd.grad(out_k2, logprob_k2)
+
+    assert torch.allclose(grad_plus, grad_k2)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py
index e777c0903c0..227c1afb156 100644
--- a/verl/trainer/ppo/core_algos.py
+++ b/verl/trainer/ppo/core_algos.py
@@ -2135,7 +2135,9 @@ def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_pe
     Returns:
         kl_estimate
     """
-    forward_score = kl_penalty_forward(logprob, ref_logprob, kl_penalty)
+    # Strip the optional '+' suffix so e.g. "k3+" dispatches to "k3".
+    base_kl_penalty = kl_penalty[:-1] if kl_penalty.endswith("+") else kl_penalty
+    forward_score = kl_penalty_forward(logprob, ref_logprob, base_kl_penalty)
     if not kl_penalty.endswith("+") or kl_penalty in ("mse", "k2"):
         return forward_score