Allow Exclusion of Input IDs from RepetitionPenaltyLogitsProcessor #37625
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the `Ready for review` button.
Force-pushed from 6b60d87 to 94ab64e
Hi Alex 👋 Thank you for opening this PR! This is the first time I'm reading about this feature request. As such, there are a few workarounds I'd like to ask you to try before adding your proposal -- minimizing flags greatly improves the UX of the library in the long run 🤗 If none of them solves your problem (or if they have unwanted secondary effects), let me know!
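For context, one workaround along these lines (also mentioned in the PR description below) is to pass input embeddings instead of input IDs, so the prompt tokens never appear in `input_ids` and are therefore not penalized. A rough sketch for a decoder-only model, with illustrative variable names:

```python
# Workaround sketch: feed inputs_embeds so the prompt tokens are absent from
# input_ids and thus excluded from the repetition penalty.
inputs_embeds = model.get_input_embeddings()(model_inputs["input_ids"])
model_outputs = model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=model_inputs["attention_mask"],
    repetition_penalty=3.0,
)
```

As noted in the description, this is harder to pull off for multimodal models, where the input IDs are usually needed up front to build the merged multimodal embeddings.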
Hi @gante, thanks a lot for the quick response, I really appreciate the suggestions! That all makes total sense and I agree. I'm not a fan of the flag either, but thought it might be nice to avoid changing the default behavior. Do you have any thoughts on keeping the change to the `RepetitionPenaltyLogitsProcessor` only? This might be simpler, because then if people do end up hitting this in the future, there is a clear workaround by passing it as a custom logits processor, and it avoids touching the `generate` API 🙂

```python
from transformers import RepetitionPenaltyLogitsProcessor
...
logits_processor = [
    RepetitionPenaltyLogitsProcessor(
        penalty=3.0,
        input_ids_seq_length=model_inputs["input_ids"].shape[-1],
    )
]
model_outputs = model.generate(
    **model_inputs,
    ...
    logits_processor=logits_processor,
)
```
Yeah, I'd be happy to add the flag to the processor, but not to the generation config -- it would indeed be the best of both worlds 🤗
gante left a comment
Looks mostly good to me, added a few documentation/signature nits 🤗
```
penalty (`float`):
    The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated
    tokens. Between 0.0 and 1.0 rewards previously generated tokens.
input_ids_seq_length (`int`, *optional*, defaults to 0):
```
Let's rename it to something like `prompt_ignore_length`, so that it is clear that it modifies the original behavior. `input_ids_seq_length`, at first glance, looks like an input we should provide to get the correct behavior, as opposed to a modified behavior.

Let's also document the flag accordingly (it's an optional modifier), and explain that if it is to be used, the class must be passed as a custom processor to `generate`. Ideally, with an example in the doctest below :D
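For reference, the resulting usage pattern (with the renamed kwarg, passed as a custom processor to `generate` as described above) looks roughly like this; the exact doctest wording in the final code may differ:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, RepetitionPenaltyLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello, my dog is", return_tensors="pt")

# Ignore the prompt tokens so only newly generated ids are penalized.
rep_pen = RepetitionPenaltyLogitsProcessor(
    penalty=1.5,
    prompt_ignore_length=inputs["input_ids"].shape[-1],
)
outputs = model.generate(**inputs, logits_processor=[rep_pen])
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```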
gante left a comment
LGTM, thank you for iterating 🤗
Awesome, thanks a lot for the quick reviews! 🙂
Allow Exclusion of Input IDs from RepetitionPenaltyLogitsProcessor (huggingface#37625)

* Allow exclusion of input IDs for repetition penalty
* Add logit proc tests for rep penalty exclusion
* Expose rep pen flag through generate
* Only slice if needed
* keep current rep pen default behavior
* Revert exposing reppen changes through generate
* Fix test arg
* Update src/transformers/generation/logits_process.py

Co-authored-by: Joao Gante <[email protected]>

* Rename to rep penalty kwarg
* Add custom repetition penalty processor example
* Validate prompt_ignore_length

---------

Co-authored-by: Joao Gante <[email protected]>
This PR adds a flag allowing the exclusion of input IDs when using the `RepetitionPenaltyLogitsProcessor`. Currently there are some workarounds for language models, e.g., passing input embeddings instead of input IDs; however, these workarounds are trickier with multimodal models, since we generally pass the input IDs first to create the merged multimodal embeddings in the model.

Adding such a flag would be really helpful for models like Granite Speech, which @avihu111 and I had recently added support for. In experiments, the model works very well with a very high repetition penalty applied to just the newly generated token IDs, but performance degrades severely when the input IDs are included.

By default, input IDs are included, to avoid changing existing default behavior.
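To illustrate the mechanism: the penalty is applied to every token id present in `input_ids`, which is why prompt tokens are affected, and the proposed flag simply slices them off first. A minimal sketch of the logic (not the exact merged implementation):

```python
import torch

def repetition_penalty(input_ids, scores, penalty, prompt_ignore_length=0):
    # With the flag set, drop the prompt so only generated ids are penalized.
    if prompt_ignore_length:
        input_ids = input_ids[:, prompt_ignore_length:]
    # Gather the current logit for every token id that has appeared so far.
    score = torch.gather(scores, 1, input_ids)
    # Divide positive logits by the penalty and multiply negative ones,
    # pushing probability mass away from already-seen tokens either way.
    score = torch.where(score < 0, score * penalty, score / penalty)
    return scores.scatter(1, input_ids, score)
```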
Fixes #36642
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
Probably @gante @zucchini-nlp @eustlb would have the most context!