19 changes: 15 additions & 4 deletions verl/trainer/ppo/core_algos.py
@@ -777,6 +777,7 @@ def agg_loss(
     batch_num_tokens: Optional[int] = None,
     global_batch_size: Optional[int] = None,
     loss_scale_factor: Optional[int] = None,
+    exclude_fully_masked_seq: Optional[bool] = None,
 ):
     """
     Aggregate the loss across global batch to ensure the loss is invariant to fsdp/megatron parallelism.
@@ -794,6 +795,10 @@ def agg_loss(
         global_batch_size: global batch size
         loss_scale_factor: scale factor for "seq-mean-token-sum-norm" mode. If None, uses loss_mask.shape[-1].
             Set this to a constant value to ensure consistent normalization throughout training.
+        exclude_fully_masked_seq: whether to exclude fully masked sequences when computing the denominator
+            for seq-mean modes. If None (default), uses the total batch size as the denominator.
+            If True, counts only non-fully-masked sequences in the denominator.
+            If False, uses the total batch size as the denominator.
 
     Returns:
         loss: a scalar `torch.Tensor`
@@ -805,16 +810,22 @@ def agg_loss(
         loss = verl_F.masked_sum(loss_mat, loss_mask) / batch_num_tokens * dp_size
     elif loss_agg_mode == "seq-mean-token-sum":
         seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum
-        seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float()  # exclude fully masked sequences
+        seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float()  # mask for non-fully-masked sequences
         if global_batch_size is None:
-            global_batch_size = seq_mask.sum()
+            if exclude_fully_masked_seq:
+                global_batch_size = seq_mask.sum() * dp_size
+            else:
+                global_batch_size = loss_mat.shape[0] * dp_size
         loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size  # seq-mean
     elif loss_agg_mode == "seq-mean-token-mean":
         seq_mask = torch.sum(loss_mask, dim=-1)  # per-sequence token count
         seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (seq_mask + 1e-8)  # token-mean
-        seq_mask = (seq_mask > 0).float()  # exclude fully masked sequences
+        seq_mask = (seq_mask > 0).float()  # mask for non-fully-masked sequences
         if global_batch_size is None:
-            global_batch_size = seq_mask.sum()
+            if exclude_fully_masked_seq:
+                global_batch_size = seq_mask.sum() * dp_size
+            else:
+                global_batch_size = loss_mat.shape[0] * dp_size
         loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size  # seq-mean
     elif loss_agg_mode == "seq-mean-token-sum-norm":
         seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
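To make the new denominator behavior concrete, here is a minimal runnable sketch of the "seq-mean-token-mean" branch on a toy batch containing one fully masked sequence. Plain torch stands in for verl_F.masked_sum, and the function below is an illustration of the logic above, not the verl implementation:

import torch

def seq_mean_token_mean(loss_mat, loss_mask, exclude_fully_masked_seq=False, dp_size=1):
    # Per-sequence token counts; a fully masked sequence has count 0.
    token_counts = torch.sum(loss_mask, dim=-1)
    seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (token_counts + 1e-8)  # token-mean
    seq_mask = (token_counts > 0).float()  # mask for non-fully-masked sequences
    if exclude_fully_masked_seq:
        global_batch_size = seq_mask.sum() * dp_size  # count only live sequences
    else:
        global_batch_size = loss_mat.shape[0] * dp_size  # count every sequence
    return torch.sum(seq_losses * seq_mask) / global_batch_size * dp_size

loss_mat = torch.tensor([[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]])
loss_mask = torch.tensor([[1.0, 1.0], [1.0, 0.0], [0.0, 0.0]])  # third sequence fully masked

print(seq_mean_token_mean(loss_mat, loss_mask, exclude_fully_masked_seq=True))   # (1.0 + 3.0) / 2 = 2.0
print(seq_mean_token_mean(loss_mat, loss_mask, exclude_fully_masked_seq=False))  # (1.0 + 3.0) / 3 ≈ 1.33

Either way the fully masked sequence contributes zero to the numerator; the flag only decides whether it still occupies a slot in the denominator.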
19 changes: 17 additions & 2 deletions verl/workers/actor/dp_actor.py
@@ -456,6 +456,11 @@ def update_policy(self, data: DataProto):
         entropy_coeff = self.config.entropy_coeff
         loss_agg_mode = self.config.loss_agg_mode
 
+        # Populate global_batch_info for loss aggregation
+        self.config.global_batch_info["dp_size"] = torch.distributed.get_world_size()
+        self.config.global_batch_info["global_batch_size"] = model_inputs.get("global_batch_size", None)
+        self.config.global_batch_info["exclude_fully_masked_seq"] = self.config.exclude_fully_masked_seq
+
         calculate_entropy = self.config.calculate_entropy or (entropy_coeff != 0)
 
         if self.config.use_dynamic_bsz:
@@ -516,7 +521,12 @@ def update_policy(self, data: DataProto):
 
                 policy_loss = pg_loss
                 if calculate_entropy and entropy is not None:
-                    entropy_agg = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
+                    entropy_agg = agg_loss(
+                        loss_mat=entropy,
+                        loss_mask=response_mask,
+                        loss_agg_mode=loss_agg_mode,
+                        **self.config.global_batch_info,
+                    )
                     micro_batch_metrics["actor/entropy"] = entropy_agg.detach().item()
                     if entropy_coeff != 0:
                         policy_loss -= entropy_agg * entropy_coeff
@@ -527,7 +537,12 @@ def update_policy(self, data: DataProto):
                     kld = kl_penalty(
                         logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type
                     )
-                    kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
+                    kl_loss = agg_loss(
+                        loss_mat=kld,
+                        loss_mask=response_mask,
+                        loss_agg_mode=loss_agg_mode,
+                        **self.config.global_batch_info,
+                    )
 
                     policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
                     micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item() * loss_scale_factor
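Since global_batch_info is splatted into agg_loss with **, its keys must match agg_loss's keyword parameters exactly. A toy sketch of the pattern follows; the stub only mimics the keyword interface of agg_loss from core_algos.py (including a dp_size parameter assumed to exist above the fragment shown), and is not the verl function:

# Toy stub mimicking agg_loss's keyword interface; illustration only.
def agg_loss(loss_mat, loss_mask, loss_agg_mode, dp_size=1, batch_num_tokens=None,
             global_batch_size=None, loss_scale_factor=None, exclude_fully_masked_seq=None):
    return f"mode={loss_agg_mode} dp_size={dp_size} exclude={exclude_fully_masked_seq}"

global_batch_info = {
    "dp_size": 4,  # torch.distributed.get_world_size() in dp_actor
    "global_batch_size": None,  # None -> agg_loss derives the denominator itself
    "exclude_fully_masked_seq": True,
}
print(agg_loss(loss_mat=None, loss_mask=None, loss_agg_mode="seq-mean-token-mean",
               **global_batch_info))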
19 changes: 17 additions & 2 deletions verl/workers/actor/megatron_actor.py
@@ -480,6 +480,11 @@ def loss_func(output, data, meta_info):
             entropy_coeff = self.config.entropy_coeff
             loss_agg_mode = self.config.loss_agg_mode
 
+            # Populate global_batch_info for loss aggregation
+            self.config.global_batch_info["dp_size"] = data.get("dp_size", 1)
+            self.config.global_batch_info["global_batch_size"] = data.get("global_batch_size", None)
+            self.config.global_batch_info["exclude_fully_masked_seq"] = self.config.exclude_fully_masked_seq
+
             loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
 
             policy_loss_fn = get_policy_loss_fn(loss_mode)
@@ -518,7 +523,12 @@ def loss_func(output, data, meta_info):
             if calculate_entropy:
                 entropy = output["entropy"][:, -response_length - 1 : -1].contiguous()
                 if not forward_only:
-                    entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
+                    entropy_loss = agg_loss(
+                        loss_mat=entropy,
+                        loss_mask=response_mask,
+                        loss_agg_mode=loss_agg_mode,
+                        **self.config.global_batch_info,
+                    )
                     entropy_coeff = meta_info["entropy_coeff"]
                     policy_loss = pg_loss - entropy_coeff * entropy_loss
                 else:
@@ -531,7 +541,12 @@ def loss_func(output, data, meta_info):
                 ref_log_prob = data["ref_log_prob"]
                 # compute kl loss
                 kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type)
-                kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode)
+                kl_loss = agg_loss(
+                    loss_mat=kld,
+                    loss_mask=response_mask,
+                    loss_agg_mode=self.config.loss_agg_mode,
+                    **self.config.global_batch_info,
+                )
 
                 policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
                 metrics["actor/kl_loss"] = kl_loss.detach().item()
6 changes: 6 additions & 0 deletions verl/workers/config/actor.py
@@ -102,6 +102,11 @@ class ActorConfig(BaseConfig):
         loss_agg_mode (str): Loss aggregation mode. Options: 'token-mean', 'sample-mean'.
         loss_scale_factor (Optional[int]): Scale factor for 'seq-mean-token-sum-norm' loss aggregation mode.
             If None, uses response_length. Set to a constant to ensure consistent normalization.
+        exclude_fully_masked_seq (Optional[bool]): Whether to exclude fully masked sequences when computing
+            the denominator for seq-mean loss aggregation modes.
+            If None (default), uses the total batch size as the denominator.
+            If True, counts only non-fully-masked sequences in the denominator.
+            If False, uses the total batch size as the denominator.
         entropy_coeff (float): Entropy coefficient for regularization.
         use_kl_loss (bool): Whether to use KL divergence loss.
         use_torch_compile (bool): Whether to use torch.compile for optimization.
@@ -141,6 +146,7 @@ class ActorConfig(BaseConfig):
     clip_ratio_c: float = 3.0
     loss_agg_mode: str = "token-mean"
     loss_scale_factor: Optional[int] = None
+    exclude_fully_masked_seq: Optional[bool] = None
     entropy_coeff: float = 0
     calculate_entropy: bool = False
     use_kl_loss: bool = False
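With the new field defaulting to None, existing configs are unaffected; opting in is a one-line change. A hypothetical sketch follows, assuming the import path implied by the file location in this PR and that the remaining ActorConfig fields keep their defaults:

# Illustration only: import path inferred from verl/workers/config/actor.py.
from verl.workers.config.actor import ActorConfig

# Opt in to excluding fully masked sequences from the seq-mean denominator.
config = ActorConfig(
    loss_agg_mode="seq-mean-token-mean",
    exclude_fully_masked_seq=True,
)
assert config.exclude_fully_masked_seq is True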
1 change: 1 addition & 0 deletions verl/workers/utils/losses.py
@@ -103,6 +103,7 @@ def ppo_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None)
     config.global_batch_info["batch_num_tokens"] = data["batch_num_tokens"]
     config.global_batch_info["global_batch_size"] = data["global_batch_size"]
     config.global_batch_info["loss_scale_factor"] = config.loss_scale_factor
+    config.global_batch_info["exclude_fully_masked_seq"] = config.exclude_fully_masked_seq
 
     metrics = {}
 