Skip to content

[worker] fix: Fix missing rollout_log_probs argument in policy loss functions#3274

Merged
vermouth1992 merged 2 commits intoverl-project:mainfrom
kAIto47802:fix-policy-loss-fn
Aug 30, 2025
Merged

[worker] fix: Fix missing rollout_log_probs argument in policy loss functions#3274
vermouth1992 merged 2 commits intoverl-project:mainfrom
kAIto47802:fix-policy-loss-fn

Conversation

@kAIto47802
Copy link
Copy Markdown
Contributor

What does this PR do?

In the recent PR:

the file workers/actor/dp_actor.py was updated so that rollout_log_probs is passed to policy_loss_fn:

https://github.com/volcengine/verl/blob/38d23914ee512a125e00763fe3ddcc8df4319346/verl/workers/actor/dp_actor.py#L448-L456

In that PR, the "vanilla" policy loss function was modified to accept rollout_log_probs as an argument. However, other policy loss functions (e.g., "gspo") were not updated accordingly, which leads to an error such as:

TypeError: compute_policy_loss_gspo() got an unexpected keyword argument 'rollout_log_probs'

when setting config.policy_loss.loss_mode to one of these alternatives.

Therefore, in this PR, rollout_log_probs is also added as an argument to the other policy loss functions.

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request correctly identifies and fixes a TypeError by adding the rollout_log_probs argument to several policy loss functions, aligning their signatures with the PolicyLossFn type hint.

However, I've noticed that while the argument is added, it's not used in the function bodies of compute_policy_loss_gspo, compute_policy_loss_gpg, compute_policy_loss_clip_cov, compute_policy_loss_kl_cov, and compute_policy_loss_geo_mean. This is inconsistent with the compute_policy_loss_vanilla function, which uses this argument to implement Truncated Importance Sampling (TIS).

This inconsistency is a correctness issue, as a user enabling TIS would expect it to apply to all supported loss modes, but it would silently fail for these. I've left specific comments on each function to suggest implementing the TIS logic for completeness and correctness.

response_mask: torch.Tensor,
loss_agg_mode: str = "seq-mean-token-mean",
config: Optional[DictConfig | ActorConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The rollout_log_probs argument is added to the function signature to fix a TypeError, but it is not used within the function body. This is inconsistent with compute_policy_loss_vanilla, which uses this argument to apply Truncated Importance Sampling (TIS).

To ensure features like TIS work consistently across all policy loss functions, please implement the TIS logic in this function as well. Otherwise, TIS will be silently disabled for the gspo loss mode, which can be misleading and is a correctness issue.

You can adapt the logic from compute_policy_loss_vanilla:

if config.tis_imp_ratio_cap > 0 and rollout_log_probs is not None:
    # Apply truncated importance sampling -> https://fengyao.notion.site/off-policy-rl
    tis_imp_ratio = torch.exp(old_log_prob - rollout_log_probs)
    tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap)
    pg_losses = pg_losses * tis_imp_ratio

This should be applied to pg_losses before it's aggregated.

response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The rollout_log_probs argument is added but remains unused. This is inconsistent with compute_policy_loss_vanilla, which uses it for Truncated Importance Sampling (TIS). For consistency and correctness, TIS should be implemented here as well. The gpg loss also computes pg_losses, so TIS can be applied before aggregation.

Additionally, the old_log_prob and config arguments are also unused. While this might be pre-existing, it's good practice to either use them or remove them if they are not needed for this loss function. Note that both are required to implement TIS.

response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

While adding rollout_log_probs fixes the TypeError, the argument is unused. This creates an inconsistency with compute_policy_loss_vanilla, which implements Truncated Importance Sampling (TIS) using this argument. Please apply the TIS logic here as well to ensure consistent behavior and prevent silent failures when TIS is enabled.

response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The rollout_log_probs argument has been added but is not utilized within the function. To maintain consistency with compute_policy_loss_vanilla and ensure correctness, please implement Truncated Importance Sampling (TIS) using this argument, similar to the vanilla implementation. Without this, TIS will not work for the kl_cov loss mode.

response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The rollout_log_probs argument is now present in the signature but is unused. This deviates from the compute_policy_loss_vanilla implementation where it's used for Truncated Importance Sampling (TIS). For correctness and feature parity across loss functions, please add the TIS logic to this function.

@vermouth1992 vermouth1992 merged commit a73b2ab into verl-project:main Aug 30, 2025
52 of 53 checks passed
yellowbee686 pushed a commit to yellowbee686/verl that referenced this pull request Aug 30, 2025
… functions (verl-project#3274)

<!--
> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.
-->

In the recent PR:

- verl-project#2953,

the file `workers/actor/dp_actor.py` was updated so that
`rollout_log_probs` is passed to `policy_loss_fn`:

https://github.com/volcengine/verl/blob/38d23914ee512a125e00763fe3ddcc8df4319346/verl/workers/actor/dp_actor.py#L448-L456

In that PR, the "vanilla" policy loss function was modified to accept
`rollout_log_probs` as an argument. However, other policy loss functions
(e.g., "gspo") were not updated accordingly, which leads to an error
such as:

```
TypeError: compute_policy_loss_gspo() got an unexpected keyword argument 'rollout_log_probs'
```

when setting `config.policy_loss.loss_mode` to one of these
alternatives.

Therefore, in this PR, `rollout_log_probs` is also added as an argument
to the other policy loss functions.

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
```

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
cczitong123 pushed a commit to cczitong123/verl that referenced this pull request Sep 5, 2025
… functions (verl-project#3274)

### What does this PR do?

<!--
> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.
-->

In the recent PR:

- verl-project#2953,

the file `workers/actor/dp_actor.py` was updated so that
`rollout_log_probs` is passed to `policy_loss_fn`:


https://github.com/volcengine/verl/blob/38d23914ee512a125e00763fe3ddcc8df4319346/verl/workers/actor/dp_actor.py#L448-L456

In that PR, the "vanilla" policy loss function was modified to accept
`rollout_log_probs` as an argument. However, other policy loss functions
(e.g., "gspo") were not updated accordingly, which leads to an error
such as:

```
TypeError: compute_policy_loss_gspo() got an unexpected keyword argument 'rollout_log_probs'
```

when setting `config.policy_loss.loss_mode` to one of these
alternatives.

Therefore, in this PR, `rollout_log_probs` is also added as an argument
to the other policy loss functions.



### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
yellowbee686 pushed a commit to yellowbee686/verl that referenced this pull request Sep 5, 2025
… functions (verl-project#3274)

<!--
> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.
-->

In the recent PR:

- verl-project#2953,

the file `workers/actor/dp_actor.py` was updated so that
`rollout_log_probs` is passed to `policy_loss_fn`:

https://github.com/volcengine/verl/blob/38d23914ee512a125e00763fe3ddcc8df4319346/verl/workers/actor/dp_actor.py#L448-L456

In that PR, the "vanilla" policy loss function was modified to accept
`rollout_log_probs` as an argument. However, other policy loss functions
(e.g., "gspo") were not updated accordingly, which leads to an error
such as:

```
TypeError: compute_policy_loss_gspo() got an unexpected keyword argument 'rollout_log_probs'
```

when setting `config.policy_loss.loss_mode` to one of these
alternatives.

Therefore, in this PR, `rollout_log_probs` is also added as an argument
to the other policy loss functions.

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
```

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
DDVD233 pushed a commit to DDVD233/mirl that referenced this pull request Sep 5, 2025
… functions (verl-project#3274)

### What does this PR do?

<!--
> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.
-->

In the recent PR:

- verl-project#2953,

the file `workers/actor/dp_actor.py` was updated so that
`rollout_log_probs` is passed to `policy_loss_fn`:


https://github.com/volcengine/verl/blob/38d23914ee512a125e00763fe3ddcc8df4319346/verl/workers/actor/dp_actor.py#L448-L456

In that PR, the "vanilla" policy loss function was modified to accept
`rollout_log_probs` as an argument. However, other policy loss functions
(e.g., "gspo") were not updated accordingly, which leads to an error
such as:

```
TypeError: compute_policy_loss_gspo() got an unexpected keyword argument 'rollout_log_probs'
```

when setting `config.policy_loss.loss_mode` to one of these
alternatives.

Therefore, in this PR, `rollout_log_probs` is also added as an argument
to the other policy loss functions.



### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
WncFht pushed a commit to WncFht/verl that referenced this pull request Oct 10, 2025
… functions (verl-project#3274)

### What does this PR do?

<!--
> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.
-->

In the recent PR:

- verl-project#2953,

the file `workers/actor/dp_actor.py` was updated so that
`rollout_log_probs` is passed to `policy_loss_fn`:


https://github.com/volcengine/verl/blob/38d23914ee512a125e00763fe3ddcc8df4319346/verl/workers/actor/dp_actor.py#L448-L456

In that PR, the "vanilla" policy loss function was modified to accept
`rollout_log_probs` as an argument. However, other policy loss functions
(e.g., "gspo") were not updated accordingly, which leads to an error
such as:

```
TypeError: compute_policy_loss_gspo() got an unexpected keyword argument 'rollout_log_probs'
```

when setting `config.policy_loss.loss_mode` to one of these
alternatives.

Therefore, in this PR, `rollout_log_probs` is also added as an argument
to the other policy loss functions.



### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
masoudhashemi pushed a commit to masoudhashemi/verl that referenced this pull request Oct 19, 2025
… functions (verl-project#3274)

### What does this PR do?

<!--
> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.
-->

In the recent PR:

- verl-project#2953,

the file `workers/actor/dp_actor.py` was updated so that
`rollout_log_probs` is passed to `policy_loss_fn`:


https://github.com/volcengine/verl/blob/38d23914ee512a125e00763fe3ddcc8df4319346/verl/workers/actor/dp_actor.py#L448-L456

In that PR, the "vanilla" policy loss function was modified to accept
`rollout_log_probs` as an argument. However, other policy loss functions
(e.g., "gspo") were not updated accordingly, which leads to an error
such as:

```
TypeError: compute_policy_loss_gspo() got an unexpected keyword argument 'rollout_log_probs'
```

when setting `config.policy_loss.loss_mode` to one of these
alternatives.

Therefore, in this PR, `rollout_log_probs` is also added as an argument
to the other policy loss functions.



### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
techkang pushed a commit to techkang/verl that referenced this pull request Oct 31, 2025
… functions (verl-project#3274)

### What does this PR do?

<!--
> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.
-->

In the recent PR:

- verl-project#2953,

the file `workers/actor/dp_actor.py` was updated so that
`rollout_log_probs` is passed to `policy_loss_fn`:


https://github.com/volcengine/verl/blob/38d23914ee512a125e00763fe3ddcc8df4319346/verl/workers/actor/dp_actor.py#L448-L456

In that PR, the "vanilla" policy loss function was modified to accept
`rollout_log_probs` as an argument. However, other policy loss functions
(e.g., "gspo") were not updated accordingly, which leads to an error
such as:

```
TypeError: compute_policy_loss_gspo() got an unexpected keyword argument 'rollout_log_probs'
```

when setting `config.policy_loss.loss_mode` to one of these
alternatives.

Therefore, in this PR, `rollout_log_probs` is also added as an argument
to the other policy loss functions.



### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
chenjiaoAngel added a commit to chenjiaoAngel/verl that referenced this pull request Nov 14, 2025
… functions (verl-project#3274)

### What does this PR do?

<!--
> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.
-->

In the recent PR:

- verl-project#2953,

the file `workers/actor/dp_actor.py` was updated so that
`rollout_log_probs` is passed to `policy_loss_fn`:


https://github.com/volcengine/verl/blob/38d23914ee512a125e00763fe3ddcc8df4319346/verl/workers/actor/dp_actor.py#L448-L456

In that PR, the "vanilla" policy loss function was modified to accept
`rollout_log_probs` as an argument. However, other policy loss functions
(e.g., "gspo") were not updated accordingly, which leads to an error
such as:

```
TypeError: compute_policy_loss_gspo() got an unexpected keyword argument 'rollout_log_probs'
```

when setting `config.policy_loss.loss_mode` to one of these
alternatives.

Therefore, in this PR, `rollout_log_probs` is also added as an argument
to the other policy loss functions.



### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
TimurTaepov pushed a commit to giorgossideris/verl that referenced this pull request Dec 20, 2025
… functions (verl-project#3274)

### What does this PR do?

<!--
> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.
-->

In the recent PR:

- verl-project#2953,

the file `workers/actor/dp_actor.py` was updated so that
`rollout_log_probs` is passed to `policy_loss_fn`:


https://github.com/volcengine/verl/blob/38d23914ee512a125e00763fe3ddcc8df4319346/verl/workers/actor/dp_actor.py#L448-L456

In that PR, the "vanilla" policy loss function was modified to accept
`rollout_log_probs` as an argument. However, other policy loss functions
(e.g., "gspo") were not updated accordingly, which leads to an error
such as:

```
TypeError: compute_policy_loss_gspo() got an unexpected keyword argument 'rollout_log_probs'
```

when setting `config.policy_loss.loss_mode` to one of these
alternatives.

Therefore, in this PR, `rollout_log_probs` is also added as an argument
to the other policy loss functions.



### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
vyomakesh0728 added a commit to vyomakesh0728/verl that referenced this pull request Jan 22, 2026
… functions (verl-project#3274)

### What does this PR do?

<!--
> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.
-->

In the recent PR:

- verl-project#2953,

the file `workers/actor/dp_actor.py` was updated so that
`rollout_log_probs` is passed to `policy_loss_fn`:


https://github.com/volcengine/verl/blob/38d23914ee512a125e00763fe3ddcc8df4319346/verl/workers/actor/dp_actor.py#L448-L456

In that PR, the "vanilla" policy loss function was modified to accept
`rollout_log_probs` as an argument. However, other policy loss functions
(e.g., "gspo") were not updated accordingly, which leads to an error
such as:

```
TypeError: compute_policy_loss_gspo() got an unexpected keyword argument 'rollout_log_probs'
```

when setting `config.policy_loss.loss_mode` to one of these
alternatives.

Therefore, in this PR, `rollout_log_probs` is also added as an argument
to the other policy loss functions.



### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants