Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions tests/trainer/ppo/test_core_algos_on_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
compute_rloo_outcome_advantage,
compute_rloo_vectorized_outcome_advantage,
get_adv_estimator_fn,
kl_penalty,
register_adv_est,
)

Expand Down Expand Up @@ -313,5 +314,52 @@ def test_grpo_and_vectorized_equivalence(batch_size: int, seq_len: int, num_grou
assert torch.allclose(ret1, ret2, rtol=1e-5, atol=1e-6)


@pytest.mark.parametrize(
    "name,base",
    [
        ("k1+", "k1"),
        ("kl+", "kl"),
        ("abs+", "abs"),
        ("k3+", "k3"),
        ("low_var_kl+", "low_var_kl"),
    ],
)
def test_kl_penalty_straight_through_value_matches_base(name, base):
    """A ``+``-suffixed estimator is a straight-through construction: the k2
    gradient is swapped in for the backward pass while the forward value of
    the base estimator is preserved. The forward result of e.g. ``k3+`` must
    therefore be identical to that of plain ``k3``.

    Guards against the regression where ``kl_penalty(..., "k3+")`` raised
    ``NotImplementedError`` because the wrapper forwarded the ``+`` suffix to
    ``kl_penalty_forward`` without stripping it.
    """
    torch.manual_seed(0)
    log_probs = torch.randn(4, 8, requires_grad=True)
    ref_log_probs = torch.randn(4, 8)

    value_with_suffix = kl_penalty(log_probs, ref_log_probs, name)
    value_of_base = kl_penalty(log_probs, ref_log_probs, base)
    assert torch.allclose(value_with_suffix, value_of_base)


def test_kl_penalty_k3_plus_uses_k2_gradient():
    """Check that ``k3+`` backpropagates through the ``k2`` estimator.

    The straight-through trick routes the backward pass of a ``+``-suffixed
    estimator through ``k2`` (``0.5 * log_ratio**2``), so the gradient of
    ``k3+`` w.r.t. ``logprob`` must coincide with the gradient obtained from
    plain ``k2``.
    """
    torch.manual_seed(0)
    log_probs = torch.randn(4, 8, requires_grad=True)
    ref_log_probs = torch.randn(4, 8)

    k3_plus_loss = kl_penalty(log_probs, ref_log_probs, "k3+").sum()
    (k3_plus_grad,) = torch.autograd.grad(k3_plus_loss, log_probs)

    detached_log_probs = log_probs.detach().clone().requires_grad_(True)
    k2_loss = kl_penalty(detached_log_probs, ref_log_probs, "k2").sum()
    (k2_grad,) = torch.autograd.grad(k2_loss, detached_log_probs)

    assert torch.allclose(k3_plus_grad, k2_grad)


if __name__ == "__main__":
    # The tests in this module are plain pytest-style functions (one of them
    # parametrized), not unittest.TestCase methods, so unittest.main() would
    # collect zero tests and report a misleading "OK". Run pytest on this
    # file instead; ``pytest`` is already in scope via the parametrize
    # decorator above.
    raise SystemExit(pytest.main([__file__]))
4 changes: 3 additions & 1 deletion verl/trainer/ppo/core_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2135,7 +2135,9 @@ def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_pe
Returns:
kl_estimate
"""
forward_score = kl_penalty_forward(logprob, ref_logprob, kl_penalty)
# Strip the optional '+' suffix so e.g. "k3+" dispatches to "k3".
base_kl_penalty = kl_penalty[:-1] if kl_penalty.endswith("+") else kl_penalty
forward_score = kl_penalty_forward(logprob, ref_logprob, base_kl_penalty)
if not kl_penalty.endswith("+") or kl_penalty in ("mse", "k2"):
return forward_score

Expand Down
Loading