
Commit 7bc86c1

[algo] fix: Add seq mean mask denominator option (verl-project#4510)
## Summary

Refactor the `agg_loss` function and fix entropy/KL loss scaling in distributed training.

**Changes:**

- **Refactor**: Unify the `seq-mean-*` modes with shared denominator logic using `masked_sum`
- **Behavior change**: `seq-mean-token-sum-norm` now applies the seq-mean division (denominator = `global_batch_size * dp_size`, or `local_bsz` as a fallback), matching the mode name
- **Simplification**: Remove the exclusion of fully masked sequences from the denominator; use the total batch size consistently

NOTE: Since the global loss aggregation logic is not compatible with the legacy model engine, which conducts the aggregation outside `agg_loss` and is going to be deprecated, this PR leaves the legacy model engine unmodified.

⚠️ **Breaking**: `seq-mean-token-sum-norm` now divides by both `loss_scale_factor` AND `seq_denominator`; previously it divided only by `loss_scale_factor`.

## Test plan

- [ ] Verify PPO training with `seq-mean-token-sum` mode
- [ ] Verify PPO training with `seq-mean-token-mean` mode
- [ ] Verify PPO training with `seq-mean-token-sum-norm` mode (note: behavior changed)
- [ ] Confirm entropy/KL loss values are correctly scaled in multi-GPU training

---------

Co-authored-by: Shawn/Yuxuan Tong <tongyuxuan361@gmail.com>
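To make the breaking change concrete, here is a minimal, torch-free sketch of `seq-mean-token-sum-norm` before and after this commit. Plain lists stand in for tensors, and the helper names (`seq_token_sums`, `sum_norm_before`, `sum_norm_after`) are illustrative, not verl APIs:

```python
# Illustrative sketch of the seq-mean-token-sum-norm breaking change.
# `loss_mat` / `loss_mask` mirror agg_loss's (bs, response_length) inputs.

def seq_token_sums(loss_mat, loss_mask):
    """Masked token-sum per sequence."""
    return [sum(v * m for v, m in zip(row, mask))
            for row, mask in zip(loss_mat, loss_mask)]

def sum_norm_before(loss_mat, loss_mask, loss_scale_factor):
    # Before this commit: divide only by loss_scale_factor.
    return sum(seq_token_sums(loss_mat, loss_mask)) / loss_scale_factor

def sum_norm_after(loss_mat, loss_mask, loss_scale_factor):
    # After: additionally divide by the sequence denominator
    # (the local batch size when global_batch_size is None).
    local_bsz = len(loss_mat)
    scaled = [s / loss_scale_factor for s in seq_token_sums(loss_mat, loss_mask)]
    return sum(scaled) / local_bsz

loss_mat = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
loss_mask = [[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
print(sum_norm_before(loss_mat, loss_mask, 3))  # (3 + 4) / 3
print(sum_norm_after(loss_mat, loss_mask, 3))   # (3 + 4) / 3 / 2
```

So with a batch of two sequences, the new mode yields half the old value; training runs that tuned learning rates against the old scaling should re-verify convergence.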
1 parent 7290aef commit 7bc86c1

File tree

2 files changed: +35 −21 lines changed

verl/trainer/ppo/core_algos.py

Lines changed: 35 additions & 20 deletions

```diff
@@ -781,15 +781,21 @@ def agg_loss(
     """
     Aggregate the loss across global batch to ensure the loss is invariant to fsdp/megatron parallelism.
 
+    NOTE: ``dp_size``, ``batch_num_tokens``, and ``global_batch_size`` are only compatible with the new model engine
+    for now, while the legacy model engines conduct the aggregation outside ``agg_loss``.
+
     NOTE: The returned loss has different behaviors for different backend:
     - FSDP: the loss is directly used for backward.
     - Megatron: the loss should be scaled by `num_microbatches` and `cp_size` for pp schedule.
+    # TODO: Consider the numerical stability?
 
     Args:
         loss_mat: micro batch loss matrix, (bs, response_length)
         loss_mask: micro batch loss mask, (bs, response_length)
         loss_agg_mode: method to aggregate the loss matrix into a scalar
-        dp_size: data parallel size
+        dp_size: data parallel size. When applying manual aggregation,
+            scaling up the ``loss`` by ``dp_size`` can cancel out FSDP averaging.
         batch_num_tokens: number of valid tokens in global batch
         global_batch_size: global batch size
         loss_scale_factor: scale factor for "seq-mean-token-sum-norm" mode. If None, uses loss_mask.shape[-1].
@@ -799,30 +805,39 @@ def agg_loss(
         loss: `a scalar torch.Tensor`
             aggregated loss
     """
+    # NOTE: `masked_sum` is more robust than multiplying the `mask`.
    if loss_agg_mode == "token-mean":
        if batch_num_tokens is None:
            batch_num_tokens = loss_mask.sum()
        loss = verl_F.masked_sum(loss_mat, loss_mask) / batch_num_tokens * dp_size
-    elif loss_agg_mode == "seq-mean-token-sum":
-        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum
-        seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float()  # exclude fully masked sequences
-        if global_batch_size is None:
-            global_batch_size = seq_mask.sum()
-        loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size  # seq-mean
-    elif loss_agg_mode == "seq-mean-token-mean":
-        seq_mask = torch.sum(loss_mask, dim=-1)  # per-sequence token count
-        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (seq_mask + 1e-8)  # token-mean
-        seq_mask = (seq_mask > 0).float()  # exclude fully masked sequences
-        if global_batch_size is None:
-            global_batch_size = seq_mask.sum()
-        loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size  # seq-mean
-    elif loss_agg_mode == "seq-mean-token-sum-norm":
-        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
-        if loss_scale_factor is None:
-            loss_scale_factor = loss_mask.shape[-1]
-        loss = torch.sum(seq_losses) / loss_scale_factor
+    elif loss_agg_mode.startswith("seq-mean"):
+        # TODO: Correct and unify the denominator logic.
+        if global_batch_size is not None:
+            seq_denominator = global_batch_size * dp_size
+        else:  # The default logic, which is only correct when the batch sizes are even.
+            local_bsz = loss_mat.shape[0]
+            seq_denominator = local_bsz
+
+        if loss_agg_mode.startswith("seq-mean-token-sum"):
+            seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1)  # token-sum per sequence
+
+            if loss_agg_mode == "seq-mean-token-sum":
+                pass  # TODO: Add assertion.
+            elif loss_agg_mode == "seq-mean-token-sum-norm":
+                if loss_scale_factor is None:
+                    loss_scale_factor = loss_mask.shape[-1]
+                seq_losses = seq_losses / loss_scale_factor
+            else:
+                raise ValueError(f"Invalid {loss_agg_mode=}")
+        elif loss_agg_mode == "seq-mean-token-mean":
+            token_counts = torch.sum(loss_mask, dim=-1)  # per-sequence token count
+            # token-mean per sequence
+            seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1) / (token_counts + 1e-8)
+        else:
+            raise ValueError(f"Invalid {loss_agg_mode=}")
+        loss = torch.sum(seq_losses) / seq_denominator  # seq-mean
    else:
-        raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")
+        raise ValueError(f"Invalid {loss_agg_mode=}")
 
    return loss
```
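The unified control flow of the new `seq-mean` branch can be sketched dependency-free. This is a plain-Python illustration, not verl's implementation: lists stand in for tensors, and `masked_sum` / `agg_loss_sketch` are hypothetical stand-ins for `verl_F.masked_sum` and `agg_loss`:

```python
# Torch-free sketch of the unified seq-mean aggregation in this commit.

def masked_sum(row, mask):
    """Masked sum over one sequence (stand-in for verl_F.masked_sum, axis=-1)."""
    return sum(v * m for v, m in zip(row, mask))

def agg_loss_sketch(loss_mat, loss_mask, mode,
                    dp_size=1, global_batch_size=None, loss_scale_factor=None):
    # Shared denominator: global batch size when provided, otherwise the
    # local batch size (only correct when batch sizes are even across ranks).
    if global_batch_size is not None:
        seq_denominator = global_batch_size * dp_size
    else:
        seq_denominator = len(loss_mat)

    if mode.startswith("seq-mean-token-sum"):
        seq_losses = [masked_sum(r, m) for r, m in zip(loss_mat, loss_mask)]
        if mode == "seq-mean-token-sum-norm":
            sf = loss_scale_factor if loss_scale_factor is not None else len(loss_mask[0])
            seq_losses = [s / sf for s in seq_losses]
        elif mode != "seq-mean-token-sum":
            raise ValueError(f"Invalid {mode=}")
    elif mode == "seq-mean-token-mean":
        seq_losses = [masked_sum(r, m) / (sum(m) + 1e-8)
                      for r, m in zip(loss_mat, loss_mask)]
    else:
        raise ValueError(f"Invalid {mode=}")
    return sum(seq_losses) / seq_denominator  # seq-mean

mat = [[1.0, 2.0], [3.0, 4.0]]
mask = [[1.0, 1.0], [1.0, 0.0]]
print(agg_loss_sketch(mat, mask, "seq-mean-token-sum"))       # 3.0
print(agg_loss_sketch(mat, mask, "seq-mean-token-sum",
                      global_batch_size=4, dp_size=2))        # 0.75
print(agg_loss_sketch(mat, mask, "seq-mean-token-sum-norm"))  # 1.5
```

Note how all three modes now share one `seq_denominator`, which is the crux of the refactor: only the per-sequence numerator differs by mode.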

verl/workers/actor/megatron_actor.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -479,7 +479,6 @@ def loss_func(output, data, meta_info):
 
        entropy_coeff = self.config.entropy_coeff
        loss_agg_mode = self.config.loss_agg_mode
-
        loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
 
        policy_loss_fn = get_policy_loss_fn(loss_mode)
```
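The NOTE added to `agg_loss` states that `masked_sum` is more robust than multiplying by the mask. One plausible reason, assuming a masked sum selects valid positions rather than multiplying (these helpers are illustrative, not verl's actual implementation): multiplying lets NaN/Inf in masked-out positions poison the result, since `nan * 0.0` is still `nan`:

```python
import math

def multiply_then_sum(row, mask):
    # nan * 0.0 == nan, so junk in masked-out slots poisons the sum.
    return sum(v * m for v, m in zip(row, mask))

def select_then_sum(row, mask):
    # Skipping masked positions entirely keeps the sum finite.
    return sum(v for v, m in zip(row, mask) if m)

row = [1.0, 2.0, float("nan")]
mask = [1.0, 1.0, 0.0]
print(math.isnan(multiply_then_sum(row, mask)))  # True: the nan leaks through
print(select_then_sum(row, mask))                # 3.0
```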
