
[skyrl-train][Examples] Support truncated importance sampling for `StepWiseGenerator` #570

Merged
SumanthRH merged 6 commits into NovaSky-AI:main from SumanthRH:tis-multi-turn
Oct 31, 2025

Conversation

@SumanthRH (Member) commented Oct 25, 2025

What does this PR do?

Adds support for truncated importance sampling (TIS) with the step-wise example, which means that TIS can now be used for multi-turn training.

The implementation is mostly straightforward. One thing to note is that we append an EOS token in the agent loop when stop tokens are passed. Since this token was not generated by the model, no logprob is available for it, so it is masked out.
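The EOS handling described above can be sketched as follows. This is a minimal illustration, not the actual SkyRL code; `output_ids`, `response_logprobs`, and `added_eos` are stand-in names for the generator's internals:

```python
def build_loss_mask_and_logprobs(output_ids, response_logprobs, added_eos):
    """Illustrative sketch: mask out a manually appended EOS token.

    When the agent loop appends an EOS token itself, that token was never
    sampled from the model and has no logprob, so it is masked out of the
    loss, and the logprob list is padded with a dummy value to keep all
    per-token sequences the same length.
    """
    if added_eos:
        loss_mask = [1] * (len(output_ids) - 1) + [0]
        # Dummy logprob for the appended EOS; ignored by the loss
        # because the mask is 0 at that position.
        response_logprobs = response_logprobs + [0.0]
    else:
        loss_mask = [1] * len(output_ids)
    assert len(loss_mask) == len(output_ids) == len(response_logprobs)
    return loss_mask, response_logprobs
```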

Signed-off-by: SumanthRH <sumanthrh@anyscale.com>

if cfg.generator.backend == "sglang":
    raise NotImplementedError("`trainer.algorithm.use_tis` doesn't support Sglang backend, please use vLLM")

SumanthRH (Member, Author):

This validation was too restrictive, so I have removed it.

Now we can use TIS with multi-turn generation in the step-wise example.

@gemini-code-assist (bot, Contributor) left a comment:

Code Review

This pull request adds support for truncated importance sampling (TIS) to the StepWiseGenerator. The changes correctly handle fetching and processing logprobs for non-batched, multi-turn generation. A key aspect is the proper handling of manually appended EOS tokens by masking them from the loss calculation and accounting for their missing logprobs. The related validation logic has also been updated appropriately to allow this new functionality. My review includes one suggestion to simplify the configuration validation logic in StepWiseGenerator for better clarity and maintainability.
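For context, truncated importance sampling reweights each token's loss by the ratio between the trainer's probability and the rollout engine's probability for that token, truncated from above to bound the variance introduced by the mismatch between the two. A minimal sketch, where the function name and the cap value are illustrative and not SkyRL's actual API:

```python
import math

def tis_weights(trainer_logprobs, rollout_logprobs, cap=2.0):
    """Per-token truncated importance weights: min(pi_train / pi_rollout, cap).

    `rollout_logprobs` come from the inference engine (e.g. vLLM) and
    `trainer_logprobs` from the training forward pass; the ratio is
    computed in log space and clipped at `cap`.
    """
    return [min(math.exp(t - r), cap)
            for t, r in zip(trainer_logprobs, rollout_logprobs)]
```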

Commit: x (Signed-off-by: SumanthRH <sumanthrh@anyscale.com>)

@CharlieFRuan (Collaborator) left a comment:

Thanks a lot and sorry for the delay. Left two nits!

Comment on lines +352 to +355

# if `added_eos` is `True`, then the EOS token was not generated and was only
# added in the agent loop. For consistency with other entities like logprobs,
# we ignore it in the loss mask
loss_mask = [1] * len(output_ids) if not added_eos else [1] * (len(output_ids) - 1) + [0]
Collaborator:

Good point! Could you make this change to the base skyrl_gym_generator.py instead? So that we have the same behavior and only coded in one place.

stop_reason = engine_output["stop_reasons"][0]
response_logprobs = engine_output.get("response_logprobs", None)
if response_logprobs is not None:
    response_logprobs = response_logprobs[0]
Collaborator:
Can we add an assert here: `assert len(output_ids) == len(response_logprobs)`? Similarly for `generate_batched`.

I see we do `response_logprobs += [0] * (len(loss_mask) - len(response_logprobs))` later on here, and `generate_batched` does `sample_logprobs = logprobs[i][: len(response_ids)]`. I am worried that this hides a potential mismatch between the logprobs length and the output ids length.
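The check being requested could look roughly like this. This is a hedged sketch, with a toy `engine_output` dict standing in for the real engine response:

```python
# Toy engine response; the real one comes from the inference engine.
engine_output = {
    "response_ids": [[5, 6, 7]],
    "response_logprobs": [[-0.5, -0.4, -0.3]],
}

output_ids = engine_output["response_ids"][0]
response_logprobs = engine_output.get("response_logprobs", None)
if response_logprobs is not None:
    response_logprobs = response_logprobs[0]
    # Fail fast if re-tokenization caused a length mismatch, before any
    # later padding (e.g. `response_logprobs += [0] * ...`) can hide it.
    assert len(output_ids) == len(response_logprobs), (
        f"logprobs length {len(response_logprobs)} != "
        f"output ids length {len(output_ids)}"
    )
```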

Member Author:

Done! Actually, I added the assert at the very end to catch issues. Should be good.

Collaborator:

But if you add it at the very end, it runs after `response_logprobs += [0] * (len(loss_mask) - len(response_logprobs))`, right? That hides any mismatch introduced before that point. I was hoping to add `assert len(output_ids) == len(response_logprobs)` right after parsing them, since re-tokenization can cause a mismatch.

@CharlieFRuan (Collaborator) commented Oct 31, 2025:

But I guess it is trivial because we do TI/TO when we use logprobs. Though I'm worried this assumption could change in later PRs

Member Author:

I think we can revisit this later. I don't think the assertion you mentioned should be here; it should be a unit test for the inference engine abstractions.

Commits: x, x, x (each Signed-off-by: SumanthRH <sumanthrh99@gmail.com>)
@SumanthRH (Member, Author):

/gemini review

@gemini-code-assist (bot, Contributor) left a comment:

Code Review

This pull request adds support for Truncated Importance Sampling (TIS) in the StepWiseGenerator, which enables its use in multi-turn training scenarios. The changes are mostly centered around handling the logprobs that are necessary for TIS.

The implementation correctly handles logprobs for tokens that are manually added (like EOS tokens) by masking them out in the loss and padding the logprobs array. The validation logic has also been updated to allow logprobs in non-batched mode, which is a necessary change for the StepWiseGenerator.

I've identified a couple of areas for improvement related to code consistency and clarity. One is an inconsistency in padding values for logprobs across different generator implementations. The other is a potentially confusing method override with a different signature in StepWiseGenerator. My detailed comments are below.

Overall, the changes are logical and enable the desired functionality.

@CharlieFRuan (Collaborator) left a comment:

Thank you!

SumanthRH merged commit 9c3d1a4 into NovaSky-AI:main on Oct 31, 2025
3 checks passed
li-boxuan pushed a commit to li-boxuan/SkyRL that referenced this pull request Nov 23, 2025
…epWiseGenerator` (NovaSky-AI#570)

# What does this PR do?

Adds support for truncated importance sampling (TIS ) with the step wise
example. This means that TIS can now be used with multi-turn training

The implementation is mostly straightforward. One thing to note is that
we append an EOS token in the agent loop when stop tokens are passed.
This is a new token that was not generated by the model and thus
logprobs are not available - thus this token is masked out.

---------

Signed-off-by: SumanthRH <sumanthrh@anyscale.com>
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
dzorlu pushed a commit to fleet-ai/SkyRL that referenced this pull request Feb 4, 2026
…epWiseGenerator` (NovaSky-AI#570)

