forked from verl-project/verl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcore_algos.py
More file actions
2487 lines (2072 loc) · 99.4 KB
/
core_algos.py
File metadata and controls
2487 lines (2072 loc) · 99.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Core functions to implement PPO algorithms.
The function implemented in this file should be used by trainer with different distributed strategies to
implement PPO-like algorithms.
"""
__all__ = ["register_adv_est", "get_adv_estimator_fn", "AdvantageEstimator"]
from collections import defaultdict
from enum import Enum
from typing import Any, Callable, Optional
import numpy as np
import torch
from omegaconf import DictConfig
import verl.utils.torch_functional as verl_F
from verl.trainer.config import AlgoConfig
from verl.utils import as_torch_index, group_mean_std
from verl.utils.import_utils import deprecated
from verl.workers.config import ActorConfig
# Call signature every registered policy-loss function is expected to follow.
# Positional arguments, in order:
#   old_log_prob, log_prob, advantages, response_mask -> torch.Tensor
#   loss_agg_mode                                     -> str
#   config                                            -> DictConfig | ActorConfig | None
#   rollout_log_probs                                 -> torch.Tensor | None
# Return value: a tensor plus a dict of extra values (presumably the loss and
# its logging metrics -- confirm against the registered implementations).
PolicyLossFn = Callable[
    [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, Optional[DictConfig | ActorConfig], torch.Tensor | None],
    tuple[torch.Tensor, dict[str, Any]],
]

# Name -> policy-loss function, filled in by the ``register_policy_loss`` decorator.
POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = dict()
def register_policy_loss(name: str) -> Callable[[PolicyLossFn], PolicyLossFn]:
    """Build a decorator that records a policy-loss function under ``name``.

    The decorated function is stored in ``POLICY_LOSS_REGISTRY`` and handed back
    unchanged. Registering another function under an existing name replaces
    the previous entry without any warning.

    Args:
        name (str): Registry key for the decorated function.

    Returns:
        function: Decorator that registers its argument and returns it.
    """

    def _register(loss_fn: PolicyLossFn) -> PolicyLossFn:
        POLICY_LOSS_REGISTRY.update({name: loss_fn})
        return loss_fn

    return _register
def get_policy_loss_fn(name):
    """Look up a policy-loss function previously registered with ``register_policy_loss``.

    Args:
        name: `(str)`
            Registry key of the policy loss.

    Returns:
        `(callable)`: The registered policy-loss function.

    Raises:
        ValueError: If nothing is registered under ``name``; the message lists
            the names that are registered.
    """
    if name in POLICY_LOSS_REGISTRY:
        return POLICY_LOSS_REGISTRY[name]
    supported = list(POLICY_LOSS_REGISTRY.keys())
    raise ValueError(f"Unsupported loss mode: {name}. Supported modes are: {supported}")
class AdvantageEstimator(str, Enum):
"""Using an enumeration class to avoid spelling errors in adv_estimator.
Note(haibin.lin): this enum class is immutable after creation. Extending this
enum for new estimators may not be necessary since users can always just call
`verl.trainer.ppo.core_algos.register` with string name for a custom advantage
estimator instead.
"""
GAE = "gae"
GRPO = "grpo"
REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
REMAX = "remax"
RLOO = "rloo"
OPO = "opo"
GRPO_PASSK = "grpo_passk"
GPG = "gpg"
RLOO_VECTORIZED = "rloo_vectorized"
GRPO_VECTORIZED = "grpo_vectorized"
OPTIMAL_TOKEN_BASELINE = "optimal_token_baseline"
TIR_OPTIMAL_TOKEN_BASELINE = "tir_optimal_token_baseline"
GDPO = "gdpo"
ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}
def register_adv_est(name_or_enum: str | AdvantageEstimator) -> Any:
"""Decorator to register a advantage estimator function with a given name.
Args:
name_or_enum: `(str)` or `(AdvantageEstimator)`
The name or enum of the advantage estimator.
"""
def decorator(fn):
name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum
if name in ADV_ESTIMATOR_REGISTRY and ADV_ESTIMATOR_REGISTRY[name] != fn:
raise ValueError(
f"Adv estimator {name} has already been registered: {ADV_ESTIMATOR_REGISTRY[name]} vs {fn}"
)
ADV_ESTIMATOR_REGISTRY[name] = fn
return fn
return decorator
def get_adv_estimator_fn(name_or_enum):
"""Get the advantage estimator function with a given name.
Args:
name_or_enum: `(str)` or `(AdvantageEstimator)`
The name or enum of the advantage estimator.
Returns:
`(callable)`: The advantage estimator function.
"""
name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum
if name not in ADV_ESTIMATOR_REGISTRY:
raise ValueError(f"Unknown advantage estimator simply: {name}")
return ADV_ESTIMATOR_REGISTRY[name]
class AdaptiveKLController:
    """Proportional KL-coefficient controller.

    Follows the adaptive KL schedule of https://arxiv.org/pdf/1909.08593.pdf:
    ``value`` is multiplied by ``1 + e * n_steps / horizon``, where ``e`` is the
    relative KL error ``current_kl / target - 1`` clipped to [-0.2, 0.2].

    Attributes:
        value: Current KL coefficient.
        target: Target KL divergence.
        horizon: Controls how fast ``value`` adapts (larger means slower).
    """

    def __init__(self, init_kl_coef, target_kl, horizon):
        self.value, self.target, self.horizon = init_kl_coef, target_kl, horizon

    def update(self, current_kl, n_steps):
        """Rescale the KL coefficient toward the target.

        Args:
            current_kl (float): Observed KL divergence.
            n_steps (int): Number of steps covered by this update.
        """
        relative_gap = current_kl / self.target - 1
        bounded_gap = np.clip(relative_gap, -0.2, 0.2)
        self.value *= 1 + bounded_gap * n_steps / self.horizon
class FixedKLController:
    """KL controller whose coefficient never changes.

    Attributes:
        value: The constant KL coefficient given at construction.
    """

    def __init__(self, kl_coef):
        self.value = kl_coef

    def update(self, current_kl, n_steps):
        """Match the adaptive controller's interface; intentionally does nothing.

        Args:
            current_kl (float): Ignored.
            n_steps (int): Ignored.
        """
        return None
def get_kl_controller(kl_ctrl):
    """Factory function to create the KL controller described by ``kl_ctrl``.

    Args:
        kl_ctrl: Configuration object. ``type`` selects the controller
            ("fixed" or "adaptive") and ``kl_coef`` is the (initial) coefficient;
            the adaptive controller additionally reads ``target_kl`` and ``horizon``.

    Returns:
        KL controller instance (FixedKLController or AdaptiveKLController).

    Raises:
        NotImplementedError: If ``kl_ctrl.type`` is neither "fixed" nor "adaptive".
        AssertionError: If the adaptive controller's ``horizon`` is not positive.
    """
    if kl_ctrl.type == "fixed":
        return FixedKLController(kl_coef=kl_ctrl.kl_coef)
    if kl_ctrl.type == "adaptive":
        # horizon is used as a divisor in AdaptiveKLController.update.
        assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}"
        return AdaptiveKLController(
            init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon
        )
    raise NotImplementedError(
        f"Unsupported KL controller type: {kl_ctrl.type!r}. Supported types are: 'fixed', 'adaptive'."
    )
@register_adv_est(AdvantageEstimator.GAE)  # or simply: @register_adv_est("gae")
def compute_gae_advantage_return(
    token_level_rewards: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    gamma: torch.Tensor,
    lam: torch.Tensor,
):
    """Generalized Advantage Estimation over a batch of responses.

    Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py.
    The recursion walks the response dimension backwards. At positions where
    ``response_mask`` is 0, the running next-value and advantage are carried
    through unchanged, so those tokens contribute no TD error.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape is (bs, response_length)
        values: `(torch.Tensor)`
            shape is (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape is (bs, response_length). [EOS] mask. The tokens after [EOS] have mask zero.
        gamma: `(float)`
            discount factor used in RL
        lam: `(float)`
            lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length); passed through ``verl_F.masked_whiten`` with ``response_mask``.
        returns: `(torch.Tensor)`
            shape: (bs, response_length); the un-whitened advantages plus ``values``.
    """
    with torch.no_grad():
        next_value = 0
        running_adv = 0
        per_step_adv = []
        for step in range(token_level_rewards.shape[-1] - 1, -1, -1):
            mask_t = response_mask[:, step]
            value_t = values[:, step]
            td_error = token_level_rewards[:, step] + gamma * next_value - value_t
            candidate_adv = td_error + gamma * lam * running_adv
            # Masked (observation) tokens keep the previous value / advantage.
            next_value = value_t * mask_t + (1 - mask_t) * next_value
            running_adv = candidate_adv * mask_t + (1 - mask_t) * running_adv
            per_step_adv.append(running_adv)
        per_step_adv.reverse()
        advantages = torch.stack(per_step_adv, dim=1)
        returns = advantages + values
        advantages = verl_F.masked_whiten(advantages, response_mask)
    return advantages, returns
# NOTE(sgm): this implementation only considers outcome supervision, where the reward is a scalar.
@register_adv_est(AdvantageEstimator.GRPO)  # or simply: @register_adv_est("grpo")
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    GRPO advantage for outcome-only rewards (one scalar reward per response).

    Each response's score is the sum of its token-level rewards. Scores are
    centered by their group mean and, optionally, divided by ``group std + epsilon``.
    A group with a single response uses mean 0 and std 1. The resulting scalar
    is broadcast over the response and multiplied by ``response_mask``.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape is (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape is (bs, response_length)
        index: `(np.ndarray)`
            group id per sample; entries are used as dict keys
        epsilon: `(float)`
            small value to avoid division by zero
        norm_adv_by_std_in_grpo: `(bool)`
            If True, divide by the group std as in the original GRPO. If False,
            only center the scores, as in Dr.GRPO (https://arxiv.org/abs/2503.20783).
        config: `(Optional[AlgoConfig])`
            algorithm configuration object (not read here)

    Returns:
        advantages: `(torch.Tensor)`
            shape is (bs, response_length)
        returns: `(torch.Tensor)`
            shape is (bs, response_length); the same tensor as ``advantages``
    """
    scores = token_level_rewards.sum(dim=-1)
    grouped = defaultdict(list)
    group_mean = {}
    group_std = {}
    with torch.no_grad():
        batch_size = scores.shape[0]
        for row in range(batch_size):
            grouped[index[row]].append(scores[row])
        # All statistics are computed before any score is overwritten below.
        for uid, members in grouped.items():
            if not members:
                raise ValueError(f"no score in prompt index: {uid}")
            if len(members) == 1:
                group_mean[uid] = torch.tensor(0.0)
                group_std[uid] = torch.tensor(1.0)
            else:
                stacked = torch.stack(members)
                group_mean[uid] = torch.mean(stacked)
                group_std[uid] = torch.std(stacked)
        for row in range(batch_size):
            uid = index[row]
            centered = scores[row] - group_mean[uid]
            scores[row] = centered / (group_std[uid] + epsilon) if norm_adv_by_std_in_grpo else centered
        scores = scores.unsqueeze(-1) * response_mask
    return scores, scores
@register_adv_est(AdvantageEstimator.GRPO_VECTORIZED)
def compute_grpo_vectorized_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Vectorized GRPO (outcome-only).

    With r_i the summed reward of sample i and (mu_g, sigma_g) the statistics
    of its group g as returned by ``group_mean_std``:

        a_i = (r_i - mu_g) / (sigma_g + epsilon)   if norm_adv_by_std_in_grpo
        a_i = r_i - mu_g                            otherwise

    The scalar a_i is broadcast over the token dimension and multiplied by
    ``response_mask``. ``epsilon`` is also forwarded to ``group_mean_std``;
    how that helper uses it is not visible here.
    """
    with torch.no_grad():
        scores = token_level_rewards.sum(dim=-1)
        group_ids = as_torch_index(index, device=scores.device)
        group_mean, group_std, _ = group_mean_std(scores, group_ids, eps=epsilon, device=scores.device)
        centered = scores - group_mean[group_ids]
        if norm_adv_by_std_in_grpo:
            centered = centered / (group_std[group_ids] + epsilon)
        advantages = centered.unsqueeze(-1) * response_mask
    return advantages, advantages
@register_adv_est(AdvantageEstimator.GDPO)  # or simply: @register_adv_est("gdpo")
def compute_gdpo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
    non_tensor_batch: Optional[dict] = None,
    batch: Optional[dict] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    GDPO: Group reward-Decoupled Normalization Policy Optimization.

    Instead of summing all reward dimensions first (like GRPO), GDPO normalizes
    each reward dimension independently within each group before aggregation.
    This prevents a dominant reward signal from drowning out weaker ones.

    Mathematical formulation:
        Step 1 – Group-wise decoupled normalization (via GRPO per dimension):
            For each reward dimension k, within each group g:
                A_k = (r_k - μ_group(r_k)) / (σ_group(r_k) + ε)
            (only r_k - μ_group(r_k) when ``norm_adv_by_std_in_grpo`` is False)
        Step 2 – Weighted aggregation:
            A_sum = Σ_k w_k · A_k   (w_k = 1 unless ``algorithm.gdpo_reward_weights`` is set)
        Step 3 – Batch-level normalization (via masked_whiten):
            A_final = whiten(A_sum, response_mask)

    Per-dimension rewards are read only when ``config``, ``non_tensor_batch`` and
    ``batch`` are all given. Otherwise ``token_level_rewards`` is the single
    component, which reduces to GRPO followed by batch whitening.

    Args:
        token_level_rewards: (bs, response_length) – standard token-level rewards.
            Used as fallback when per-dimension rewards are not provided.
        response_mask: (bs, response_length)
        index: (bs,) – group id per sample (from ``uid``).
        epsilon: Numerical stability constant.
        norm_adv_by_std_in_grpo: Whether to normalize by std in GRPO.
        config: Algorithm configuration (optional). Reads ``gdpo_reward_keys`` and
            the optional ``gdpo_reward_weights``.
        non_tensor_batch: Non-tensor batch data containing per-dimension reward scores.
        batch: Batch data containing prompts, attention_mask, etc.

    Note:
        Ref GDPO (https://arxiv.org/abs/2601.05242).

    Returns:
        advantages: (bs, response_length)
        returns: (bs, response_length) – same as advantages (outcome-only).

    Raises:
        AssertionError: If ``gdpo_reward_keys`` is missing/empty or a key is absent
            from ``non_tensor_batch``.
        ValueError: If ``gdpo_reward_weights`` does not have exactly one weight per
            reward key.
    """
    score_list = None
    reward_weights = None
    if config is not None and non_tensor_batch is not None and batch is not None:
        gdpo_reward_keys = config.get("gdpo_reward_keys", None)
        assert gdpo_reward_keys, (
            "GDPO requires 'algorithm.gdpo_reward_keys' listing the individual reward "
            "component keys returned by compute_score (e.g. ['format_reward', 'accuracy_reward'])."
        )
        device = token_level_rewards.device
        prompt_length = batch["prompts"].size(1)
        # Position of the last response token per row. Assumes the response part of
        # attention_mask is a contiguous run of ones -- TODO confirm. The exact position
        # is not critical: compute_grpo_outcome_advantage sums over the token dimension.
        valid_response_length = batch["attention_mask"][:, prompt_length:].sum(dim=1) - 1
        score_list = []
        for key in gdpo_reward_keys:
            assert key in non_tensor_batch, (
                f"GDPO reward key '{key}' not found in non_tensor_batch. "
                f"Available keys: {list(non_tensor_batch.keys())}. "
                f"Make sure your compute_score returns a dict containing '{key}'."
            )
            comp = non_tensor_batch[key]
            rm_score = torch.tensor(np.asarray(comp, dtype=np.float32), device=device)
            rm_scores = torch.zeros_like(response_mask, dtype=torch.float32)
            rm_scores[torch.arange(rm_scores.size(0), device=device), valid_response_length] = rm_score
            score_list.append(rm_scores)
        gdpo_weights = config.get("gdpo_reward_weights", None)
        if gdpo_weights is not None:
            reward_weights = list(gdpo_weights)
    if score_list is None:
        score_list = [token_level_rewards]
    num_scores = len(score_list)
    if reward_weights is not None:
        # Without this check, too few weights fail with an opaque IndexError below
        # and extra weights are silently ignored.
        if len(reward_weights) != num_scores:
            raise ValueError(
                f"GDPO got {len(reward_weights)} reward weights for {num_scores} reward components; "
                "'algorithm.gdpo_reward_weights' must contain exactly one weight per entry of "
                "'algorithm.gdpo_reward_keys'."
            )
        weights = torch.tensor(reward_weights, dtype=torch.float32, device=token_level_rewards.device)
    else:
        weights = torch.ones(num_scores, dtype=torch.float32, device=token_level_rewards.device)
    new_advantage = None
    for i, component_scores in enumerate(score_list):
        # Step 1: normalize this reward dimension within each group.
        normalized_score, _ = compute_grpo_outcome_advantage(
            token_level_rewards=component_scores,
            response_mask=response_mask,
            index=index,
            epsilon=epsilon,
            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
            config=config,
        )
        # Step 2: weighted aggregation across dimensions.
        if new_advantage is None:
            new_advantage = weights[i] * normalized_score
        else:
            new_advantage += weights[i] * normalized_score
    # Step 3: batch-level whitening.
    advantages = verl_F.masked_whiten(new_advantage, response_mask) * response_mask
    return advantages, advantages
@register_adv_est(AdvantageEstimator.GRPO_PASSK)  # or simply: @register_adv_est("grpo_passk")
def compute_grpo_passk_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Pass@k advantage in a GRPO-style outcome-reward formulation.

    Within each group, only the highest-scoring response gets a non-zero
    advantage, equal to (best score - second-best score). If normalization is on,
    that margin is divided by (group std + epsilon). Described in
    https://arxiv.org/abs/2503.19595.

    Args:
        token_level_rewards: (bs, response_length)
        response_mask: (bs, response_length)
        index: (bs,) → group ID per sample
        epsilon: float for numerical stability
        norm_adv_by_std_in_grpo: overridden by ``config["norm_adv_by_std_in_grpo"]``
            (default True); the argument value itself is not used.
        config: (AlgoConfig) algorithm settings; required.

    Returns:
        advantages: (bs, response_length)
        returns: (bs, response_length), the same tensor as ``advantages``

    Raises:
        ValueError: If any group has fewer than 2 samples.
    """
    assert config is not None
    # The config value takes precedence over the function argument.
    norm_adv_by_std_in_grpo = config.get("norm_adv_by_std_in_grpo", True)
    scores = token_level_rewards.sum(dim=-1)  # (bs,)
    advantages = torch.zeros_like(scores)
    members_by_group = defaultdict(list)  # group id -> [(row, score), ...]
    with torch.no_grad():
        for row in range(scores.shape[0]):
            members_by_group[index[row]].append((row, scores[row]))
        for group_id, members in members_by_group.items():
            rows = [r for r, _ in members]
            group_rewards = torch.stack([s for _, s in members])  # (k,)
            if group_rewards.numel() < 2:
                raise ValueError(
                    f"Pass@k requires at least 2 samples per group. Got {group_rewards.numel()} for group {group_id}."
                )
            top_two, top_two_pos = torch.topk(group_rewards, 2)
            margin = top_two[0] - top_two[1]
            if norm_adv_by_std_in_grpo:
                margin = margin / (torch.std(group_rewards) + epsilon)
            advantages[rows[top_two_pos[0].item()]] = margin
    advantages = advantages.unsqueeze(-1) * response_mask
    return advantages, advantages
@register_adv_est(
    AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE
)  # or simply: @register_adv_est("reinforce_plus_plus_baseline")
def compute_reinforce_plus_plus_baseline_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: torch.Tensor,
    epsilon: float = 1e-6,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    RF++-baseline advantage (https://arxiv.org/abs/2501.03262) for outcome-only rewards.

    Each response's summed reward is centered by its group mean. A single-response
    group uses baseline 0. The result is tiled over the response, masked, and
    whitened across the batch with ``verl_F.masked_whiten``.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: group id per sample. NOTE(review): annotated as torch.Tensor, but
            entries are used as dict keys (a 0-dim tensor hashes by identity);
            presumably an np.ndarray of ids like the other estimators -- verify.
        epsilon: accepted but not used here.
        config: (AlgoConfig) algorithm config (not read here)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor as ``advantages``
    """
    response_length = token_level_rewards.shape[-1]
    scores = token_level_rewards.sum(dim=-1)
    grouped = defaultdict(list)
    group_mean = {}
    with torch.no_grad():
        batch_size = scores.shape[0]
        for row in range(batch_size):
            grouped[index[row]].append(scores[row])
        for uid, members in grouped.items():
            if not members:
                raise ValueError(f"no score in prompt index: {uid}")
            group_mean[uid] = torch.tensor(0.0) if len(members) == 1 else torch.mean(torch.stack(members))
        for row in range(batch_size):
            scores[row] = scores[row] - group_mean[index[row]]
        scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask
        scores = verl_F.masked_whiten(scores, response_mask) * response_mask
    return scores, scores
@register_adv_est(AdvantageEstimator.RLOO)  # or simply: @register_adv_est("rloo")
def compute_rloo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    RLOO (leave-one-out) advantage, following https://arxiv.org/abs/2402.14740.

    For a group of k > 1 responses, each summed reward r_i becomes
    r_i * k/(k-1) - mean * k/(k-1), i.e. r_i minus the mean of the other k-1
    rewards. A single-response group keeps its raw summed reward. The scalar
    is broadcast over the response and multiplied by ``response_mask``.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: group id per sample
        config: (AlgoConfig) algorithm config (not read here)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor as ``advantages``
    """
    scores = token_level_rewards.sum(dim=-1)
    grouped = defaultdict(list)
    group_mean = {}
    with torch.no_grad():
        batch_size = scores.shape[0]
        for row in range(batch_size):
            grouped[index[row]].append(scores[row])
        for uid, members in grouped.items():
            if not members:
                raise ValueError(f"no score in prompt index: {uid}")
            group_mean[uid] = torch.tensor(0.0) if len(members) == 1 else torch.mean(torch.stack(members))
        for row in range(batch_size):
            uid = index[row]
            k = len(grouped[uid])
            if k <= 1:
                continue  # singleton group: keep the raw score
            scores[row] = scores[row] * k / (k - 1) - group_mean[uid] * k / (k - 1)
        scores = scores.unsqueeze(-1) * response_mask
    return scores, scores
@register_adv_est(AdvantageEstimator.OPO)  # or simply: @register_adv_est("opo")
def compute_opo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    OPO advantage, following https://arxiv.org/pdf/2505.23585.

    The baseline of a group is the length-weighted mean of its summed rewards,
    sum(len_i * r_i) / sum(len_i), where len_i is the number of unmasked tokens.
    A single-response group uses baseline 0. The centered scalar is broadcast
    over the response and multiplied by ``response_mask``.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: group id per sample
        config: (AlgoConfig) algorithm config (not read here)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor as ``advantages``
    """
    lengths = response_mask.sum(dim=-1)
    scores = token_level_rewards.sum(dim=-1)
    group_scores = defaultdict(list)
    group_lengths = defaultdict(list)
    group_baseline = {}
    with torch.no_grad():
        batch_size = scores.shape[0]
        for row in range(batch_size):
            uid = index[row]
            group_scores[uid].append(scores[row])
            group_lengths[uid].append(lengths[row])
        for uid, members in group_scores.items():
            if not members:
                raise ValueError(f"no score in prompt index: {uid}")
            if len(members) == 1:
                group_baseline[uid] = torch.tensor(0.0)
            else:
                score_t = torch.stack(members)
                length_t = torch.stack(group_lengths[uid])
                group_baseline[uid] = (length_t * score_t).sum() / length_t.sum()
        for row in range(batch_size):
            scores[row] = scores[row] - group_baseline[index[row]]
        scores = scores.unsqueeze(-1) * response_mask
    return scores, scores
@register_adv_est(AdvantageEstimator.REINFORCE_PLUS_PLUS)  # or simply: @register_adv_est("reinforce_plus_plus")
def compute_reinforce_plus_plus_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, config: Optional[AlgoConfig] = None, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    REINFORCE++ advantage (https://arxiv.org/abs/2501.03262).

    Computes discounted returns with ``config.gamma`` by a backward scan. The
    running return is zeroed after each masked-out position, so rewards there do
    not flow to earlier tokens. The returns are then whitened with
    ``verl_F.masked_whiten`` and masked to give the advantages.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        config: (AlgoConfig) algorithm config; required, ``gamma`` is read from it.

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length); the un-whitened discounted returns
    """
    assert config is not None
    gamma = config.gamma
    with torch.no_grad():
        returns = torch.zeros_like(token_level_rewards)
        carry = 0
        for step in range(token_level_rewards.shape[1] - 1, -1, -1):
            carry = token_level_rewards[:, step] + gamma * carry
            returns[:, step] = carry
            # Reset after EOS: a masked position stops the carry from reaching step-1.
            carry = carry * response_mask[:, step]
        advantages = verl_F.masked_whiten(returns, response_mask)
        advantages = advantages * response_mask
    return advantages, returns
@register_adv_est(AdvantageEstimator.REMAX)  # or simply: @register_adv_est("remax")
def compute_remax_outcome_advantage(
    token_level_rewards: torch.Tensor,
    reward_baselines: torch.Tensor,
    response_mask: torch.Tensor,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    ReMax advantage for outcome-only rewards (https://arxiv.org/abs/2310.10505).

    Returns are the reward-to-go of the masked token rewards, computed as a
    reversed cumulative sum. Advantages subtract the per-sample
    ``reward_baselines`` at every unmasked position.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        reward_baselines: `(torch.Tensor)`
            shape: (bs,)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        config: (AlgoConfig) algorithm config (not read here)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    with torch.no_grad():
        masked_rewards = token_level_rewards * response_mask
        returns = masked_rewards.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
        baseline = reward_baselines.unsqueeze(-1) * response_mask
        advantages = returns - baseline
    return advantages, returns
@register_adv_est(AdvantageEstimator.GPG)  # or simply: @register_adv_est("gpg")
def compute_gpg_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    f_norm: float = 1.0,
    alpha: float = 1.0,
    config=None,
    **kwargs,
):
    """
    GPG advantage for outcome-only rewards (one scalar reward per response).

    Each summed reward is centered by its group mean (0 for a single-response
    group), scaled by ``alpha / f_norm``, broadcast over the response, and
    multiplied by ``response_mask``.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(np.ndarray)`
            shape: (bs,)
        epsilon: (float) accepted for interface compatibility; unused.
        f_norm: (float) divisor applied to the centered scores.
        alpha: (float) ignored -- always recomputed as
            ``bs / max(#samples with non-zero summed reward, 1)``.
        config: (dict) algorithm config; unused.

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor as ``advantages``
    """
    scores = token_level_rewards.sum(dim=-1)
    id2score = defaultdict(list)
    id2mean = {}
    with torch.no_grad():
        bsz = scores.shape[0]
        # Rescale by batch size over the number of samples with a non-zero reward.
        m = torch.count_nonzero(scores)
        alpha = bsz / m.clamp(min=1)
        for i in range(bsz):
            id2score[index[i]].append(scores[i])
        # Only the group mean is needed. The per-group std the previous version
        # computed was never read, so it is no longer computed.
        for idx, group_scores in id2score.items():
            if len(group_scores) == 1:
                id2mean[idx] = torch.tensor(0.0)
            elif len(group_scores) > 1:
                id2mean[idx] = torch.mean(torch.stack(group_scores))
            else:
                raise ValueError(f"no score in prompt index: {idx}")
        for i in range(bsz):
            scores[i] = alpha * (scores[i] - id2mean[index[i]]) / (f_norm)
        scores = scores.unsqueeze(-1) * response_mask
    return scores, scores
@register_adv_est(AdvantageEstimator.RLOO_VECTORIZED)  # or simply: @register_adv_est("rloo_vectorized")
def compute_rloo_vectorized_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Vectorized RLOO advantage, following https://arxiv.org/abs/2402.14740.

    With n the size of sample i's group and S the group's reward sum, the
    advantage is (n * r_i - S) / (n - 1), i.e. r_i minus the mean of the other
    members. Groups of size 1 get advantage 0; the loop-based
    ``compute_rloo_outcome_advantage`` instead keeps a singleton's raw score.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: group id per sample (passed to ``np.unique``)
        config: (AlgoConfig) algorithm config (not read here)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor as ``advantages``
    """
    scores = token_level_rewards.sum(dim=-1)
    with torch.no_grad():
        _, group_of = np.unique(index, return_inverse=True)
        inv = torch.from_numpy(group_of).to(scores.device)
        group_size = torch.bincount(inv)[inv].to(scores.dtype)
        group_sum = torch.bincount(inv, weights=scores)[inv]
        leave_one_out = (group_size * scores - group_sum) / (group_size - 1).clamp_min(1)
        adv = leave_one_out * (group_size > 1)
        adv = adv.unsqueeze(-1) * response_mask
    return adv, adv
@register_adv_est(AdvantageEstimator.OPTIMAL_TOKEN_BASELINE)
def compute_optimal_token_baseline_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    old_log_probs: torch.Tensor,
    sum_pi_squared: torch.Tensor,
    rollout_is_weights: torch.Tensor = None,
    handle_zero_tail: bool = True,
    epsilon: float = 1e-8,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute advantages using Optimal Token Baseline (OTB).

    Unlike the group-mean baseline, which uses a single baseline per trajectory,
    this computes a separate baseline for each timestep using cumulative path variance.

    Theory:
        For each timestep t in each prompt group:
            B_t* = E[G_t × W_t] / E[W_t]
        where W_t = Σ_{j=1}^t ||s_j||² (cumulative path-variance proxy)
        and ||s_j||² = 1 - 2π_j + Σπ²

    The cumulative sum W_t captures the "realized energy" the trajectory has
    accumulated up to timestep t, giving higher weight to high-variance paths
    when predicting rewards.

    Args:
        token_level_rewards: Rewards at each token position [shape: (bs, response_length)]
        response_mask: Binary mask for valid tokens (1) vs padding (0) [shape: (bs, response_length)]
        index: Prompt indices for grouping trajectories from same prompt [shape: (bs,)]
        old_log_probs: Log probabilities from training policy during generation [shape: (bs, response_length)]
        sum_pi_squared: Sum of squared probabilities over vocabulary Σπ² [shape: (bs, response_length)]
        rollout_is_weights: Pre-computed IS weights for W correction [shape: (bs, response_length)],
                          None if not using IS
        handle_zero_tail: If True, the baseline is set to zero on the portion of the
                          longest trajectory that extends beyond the second-longest
                          trajectory in the prompt group (only when the longest is unique).
                          Default: True
        epsilon: Small constant added to the baseline denominator for numerical stability (default: 1e-8)

    Returns:
        advantages: OTB advantage estimates [shape: (bs, response_length)]
        returns: Cumulative rewards (returns) from each position [shape: (bs, response_length)]

    Note on Rollout Importance Sampling:
        When rollout_is_weights is provided, W_t is scaled by ρ̄²(t) to minimize MSE under truncated IS:
            B_t* = Σ[G_t × ρ̄²(t) × W_t] / Σ[ρ̄²(t) × W_t]
    """
    with torch.no_grad():
        # seq_len is unpacked but not used below.
        batch_size, seq_len = token_level_rewards.shape
        device = token_level_rewards.device

        # Compute returns (reward-to-go) for each timestep: reversed cumsum of masked rewards.
        returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])

        # Step 1: Compute w_per_timestep = 1 - 2π_t + Σπ², with π_t the probability
        # of the sampled token recovered from its log-probability.
        pi_t = torch.exp(old_log_probs)
        w_per_timestep = 1 - 2 * pi_t + sum_pi_squared

        # Step 2: Apply rollout importance sampling correction (if enabled)
        if rollout_is_weights is not None:
            # Scale W by ρ̄² to minimize MSE under truncated IS
            w_per_timestep = w_per_timestep * (rollout_is_weights**2)

        # Step 3: Compute cumulative path-variance proxy: W_t = Σ_{j=1}^t w_j
        # This measures accumulated variance from the start of the trajectory up to timestep t.
        # Padding positions contribute 0 to the running sum.
        w_cumulative = (w_per_timestep * response_mask).cumsum(dim=-1)

        # Group trajectories by prompt: group id -> list of batch rows.
        prompt_groups = defaultdict(list)
        for i in range(batch_size):
            prompt_groups[index[i]].append(i)

        # Initialize baselines tensor [batch_size, seq_len]; rows never assigned keep baseline 0.
        baselines = torch.zeros_like(returns)

        # Compute per-step baseline for each prompt group
        for _, trajectory_indices in prompt_groups.items():
            N = len(trajectory_indices)
            if N == 1:
                # Single trajectory - no baseline (advantage = return)
                continue

            traj_idx = torch.tensor(trajectory_indices, device=device)

            # Extract group data [N, seq_len]
            returns_group = returns[traj_idx]
            w_cumulative_group = w_cumulative[traj_idx]
            mask_group = response_mask[traj_idx]

            # Compute per-timestep baseline: B_t = Σ[G_t × W_t] / Σ[W_t]
            # where W_t = Σ_{j=1}^t ||s_j||² (cumulative path variance).
            # Only trajectories still unmasked at step t contribute to either sum.
            # Shape: [seq_len]
            numerator = (returns_group * w_cumulative_group * mask_group).sum(dim=0)  # Sum over trajectories
            denominator = (w_cumulative_group * mask_group).sum(dim=0) + epsilon
            baseline_per_step = numerator / denominator  # [seq_len]

            # Assign the same per-step baseline to every trajectory in this group
            baselines[traj_idx] = baseline_per_step.unsqueeze(0).expand(N, -1)

            if handle_zero_tail:
                # Optionally zero out the baseline on the portion of the longest trajectory that
                # extends beyond the second-longest trajectory in the prompt group.
                # Why: if masks are contiguous prefixes (assumed -- TODO confirm), only that
                # trajectory is unmasked there. Its baseline then ≈ its own return, driving
                # the advantage to ~0; zeroing the baseline leaves advantage = return.
                response_lengths = mask_group.sum(dim=-1)
                sorted_lengths, _ = torch.sort(response_lengths)
                max_length = int(sorted_lengths[-1].item())
                # N >= 2 here, so index -2 exists.
                second_max_length = int(sorted_lengths[-2].item())
                max_length_idx = (response_lengths == max_length).nonzero(as_tuple=True)[0]
                if max_length_idx.numel() == 1 and max_length > second_max_length:
                    max_length_traj_idx = trajectory_indices[int(max_length_idx[0])]
                    baselines[max_length_traj_idx, second_max_length:] = 0.0

        # Compute advantages: A_t = G_t - B_t, zeroed on padding
        advantages = (returns - baselines) * response_mask

    return advantages, returns
@register_adv_est(AdvantageEstimator.TIR_OPTIMAL_TOKEN_BASELINE)
def compute_multi_turn_optimal_token_baseline_advantage(
token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
index: np.ndarray,
old_log_probs: torch.Tensor,
sum_pi_squared: torch.Tensor,
rollout_is_weights: torch.Tensor = None,
handle_zero_tail: bool = True,
epsilon: float = 1e-8,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
"""