[skyrl-train][Examples] Support truncated importance sampling for `StepWiseGenerator` #570
Conversation
Signed-off-by: SumanthRH <sumanthrh@anyscale.com>
```python
if cfg.generator.backend == "sglang":
    raise NotImplementedError("`trainer.algorithm.use_tis` doesn't support Sglang backend, please use vLLM")
```
This validation is too restrictive, so I have removed it. Now we can use TIS with multi-turn generation in the step-wise example.
Code Review
This pull request adds support for truncated importance sampling (TIS) to the StepWiseGenerator. The changes correctly handle fetching and processing logprobs for non-batched, multi-turn generation. A key aspect is the proper handling of manually appended EOS tokens by masking them from the loss calculation and accounting for their missing logprobs. The related validation logic has also been updated appropriately to allow this new functionality. My review includes one suggestion to simplify the configuration validation logic in StepWiseGenerator for better clarity and maintainability.
CharlieFRuan left a comment
Thanks a lot and sorry for the delay. Left two nits!
```python
# if `added_eos` is `True`, then the EOS token was not generated and only added in the
# agent loop. For consistency with other entities like logprobs, we ignore it in the
# loss mask
loss_mask = [1] * len(output_ids) if not added_eos else [1] * (len(output_ids) - 1) + [0]
```
Good point! Could you make this change in the base `skyrl_gym_generator.py` instead, so that we have the same behavior coded in only one place?
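For illustration, a shared helper along these lines could hold the behavior in one place; the name `build_loss_mask` and its exact placement are assumptions, not the actual code in this PR:

```python
# Hypothetical helper both generators could call, so the EOS-masking
# behavior lives in one place. The function name is illustrative.
def build_loss_mask(output_ids: list, added_eos: bool) -> list:
    """Return a per-token loss mask, zeroing out a manually appended EOS
    token, since the model never generated it and it has no logprob."""
    if added_eos:
        return [1] * (len(output_ids) - 1) + [0]
    return [1] * len(output_ids)

print(build_loss_mask([101, 102, 103], added_eos=True))   # -> [1, 1, 0]
print(build_loss_mask([101, 102, 103], added_eos=False))  # -> [1, 1, 1]
```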
```python
stop_reason = engine_output["stop_reasons"][0]
response_logprobs = engine_output.get("response_logprobs", None)
if response_logprobs is not None:
    response_logprobs = response_logprobs[0]
```
Can we add an assert here: `assert len(output_ids) == len(response_logprobs)`? Similarly for `generate_batched`.
I see we do `response_logprobs += [0] * (len(loss_mask) - len(response_logprobs))` later on here, and `generate_batched` does `sample_logprobs = logprobs[i][: len(response_ids)]`. I am worried that this hides a potential mismatch between the logprobs length and the output ids length.
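A toy illustration of this concern (the values are made up): once the logprobs are padded up to the loss-mask length, a length check can no longer detect a mismatch that existed at parse time.

```python
# Made-up values: one logprob is missing, e.g. due to re-tokenization.
output_ids = [11, 12, 13, 14]
response_logprobs = [-0.1, -0.2, -0.3]  # one entry short
loss_mask = [1, 1, 1, 1]

# An assert right after parsing would catch the mismatch:
mismatch_before = len(output_ids) != len(response_logprobs)

# The later padding step hides it:
response_logprobs = response_logprobs + [0] * (len(loss_mask) - len(response_logprobs))
mismatch_after = len(output_ids) != len(response_logprobs)

print(mismatch_before, mismatch_after)  # -> True False
```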
Done! Actually, I added the assert at the very end to catch issues. Should be good.
But if you add it at the very end, it runs after `response_logprobs += [0] * (len(loss_mask) - len(response_logprobs))`, right? That hides any mismatch that occurred before this point. I was hoping to add `assert len(output_ids) == len(response_logprobs)` right after parsing them; there can be a mismatch due to re-tokenization.
But I guess it is trivial because we do TI/TO when we use logprobs. Though I'm worried this assumption could change in later PRs
I think we can revisit this later. I don't think the assertion you mentioned should be here; it should be a unit test for the inference engine abstractions.
/gemini review
Code Review
This pull request adds support for Truncated Importance Sampling (TIS) in the StepWiseGenerator, which enables its use in multi-turn training scenarios. The changes are mostly centered around handling the logprobs that are necessary for TIS.
The implementation correctly handles logprobs for tokens that are manually added (like EOS tokens) by masking them out in the loss and padding the logprobs array. The validation logic has also been updated to allow logprobs in non-batched mode, which is a necessary change for the StepWiseGenerator.
I've identified a couple of areas for improvement related to code consistency and clarity. One is an inconsistency in padding values for logprobs across different generator implementations. The other is a potentially confusing method override with a different signature in StepWiseGenerator. My detailed comments are below.
Overall, the changes are logical and enable the desired functionality.
…epWiseGenerator` (NovaSky-AI#570) # What does this PR do? Adds support for truncated importance sampling (TIS) with the step-wise example. This means that TIS can now be used with multi-turn training. The implementation is mostly straightforward. One thing to note is that we append an EOS token in the agent loop when stop tokens are passed. This is a new token that was not generated by the model, and thus logprobs are not available; this token is masked out. --------- Signed-off-by: SumanthRH <sumanthrh@anyscale.com> Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
What does this PR do?
Adds support for truncated importance sampling (TIS) with the step-wise example. This means that TIS can now be used with multi-turn training.
The implementation is mostly straightforward. One thing to note is that we append an EOS token in the agent loop when stop tokens are passed. This is a new token that was not generated by the model, and thus logprobs are not available; this token is masked out.
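The append-then-mask flow described above can be sketched as follows; this is a minimal illustration of the idea, not the PR's actual code, and the function and variable names are assumptions:

```python
# Sketch: when the agent loop appends an EOS token the model did not
# generate, that token has no logprob, so it is masked from the loss
# and its logprob slot is padded with 0.
def append_eos_and_mask(output_ids, response_logprobs, eos_token_id):
    added_eos = output_ids[-1] != eos_token_id
    if added_eos:
        output_ids = output_ids + [eos_token_id]
    loss_mask = [1] * len(output_ids)
    if added_eos:
        loss_mask[-1] = 0  # manually added token: exclude from loss
        response_logprobs = response_logprobs + [0.0]  # no logprob was generated
    return output_ids, loss_mask, response_logprobs

ids, mask, lps = append_eos_and_mask([4, 5], [-0.5, -0.7], eos_token_id=2)
print(ids, mask, lps)  # -> [4, 5, 2] [1, 1, 0] [-0.5, -0.7, 0.0]
```

With this invariant, `output_ids`, `loss_mask`, and `response_logprobs` always have equal lengths, which is what makes TIS usable in multi-turn generation.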