Skip to content

Commit 38d2391

Browse files
formatting
1 parent 3a55325 commit 38d2391

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

verl/trainer/ppo/core_algos.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,7 +1282,7 @@ def compute_value_loss(
12821282

12831283

12841284
def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
1285-
"""Compute KL divergence given logprob and ref_logprob. Optionally using straight through to bind k2 on other
1285+
"""Compute KL divergence given logprob and ref_logprob. Optionally using straight through to bind k2 on other
12861286
kl penalty compute method for unbiased KL gradient estimation.
12871287
See more description in http://joschu.net/blog/kl-approx.html
12881288
@@ -1296,7 +1296,7 @@ def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_pe
12961296
forward_score = kl_penalty_forward(logprob, ref_logprob, kl_penalty)
12971297
if not kl_penalty.endswith("+") or kl_penalty in ("mse", "k2"):
12981298
return forward_score
1299-
1299+
13001300
"""
13011301
The expectation of k1 and k3 estimator is the expectaed value of KL, but the expected gradient of k1 and k3
13021302
estimator is not the expectaed gradient of KL. On the other hand k2 estimator gives right gradient estimator,

0 commit comments

Comments
 (0)