
Conversation

@alex-jw-brooks (Contributor) commented Apr 19, 2025

This PR adds a flag that allows input IDs to be excluded when using the RepetitionPenaltyLogitsProcessor. Currently there are some workarounds that may be used for language models, e.g., passing input embeddings instead of input IDs; however, these workarounds are trickier with multimodal models, since we generally pass the input IDs first to create the merged multimodal embeddings in the model.

Adding such a flag would be really helpful for models like granite speech, which @avihu111 and I had recently added support for. In experiments, the model works very well with a super high repetition penalty on just the newly generated token IDs, but performance degrades severely when the input IDs are included.

By default, input IDs are included to avoid changing existing default behaviors.

Fixes #36642

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Probably @gante @zucchini-nlp @eustlb would have the most context!

@github-actions github-actions bot marked this pull request as draft April 19, 2025 07:22
@github-actions (bot) commented

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@alex-jw-brooks alex-jw-brooks marked this pull request as ready for review April 19, 2025 07:37
@github-actions github-actions bot requested a review from gante April 19, 2025 07:37
@gante (Contributor) commented Apr 19, 2025

Hi Alex 👋 Thank you for opening this PR!

This is the first time I'm reading about this feature request. As such, there are a few workarounds I'd like to ask you to try before adding your proposal -- minimizing flags greatly improves the UX of the library in the long run 🤗 If none of them solves your problem (or if they have unwanted secondary effects), let me know!


  1. If we combine EncoderRepetitionPenaltyLogitsProcessor (a modifier on the prompt tokens) with RepetitionPenaltyLogitsProcessor (a modifier on all tokens) using matching penalty values, we can cancel out the modifier applied to the prompt (see the sketch after this list). The difference is what happens if a token that is in the prompt gets generated: with the existing combination, that token will never get a modifier, while with your solution it does.
  2. If we run a forward pass with the prompt to extract the KV cache, and then call generate passing the KV cache without passing the prompt tokens, the penalty only gets applied to new tokens. The caveat is that you need to generate the next token from the prompt manually after the forward pass and feed it to generate as input (so generate receives the cache + 1st generated token). You'll also need to hand-craft the attention mask for generate, and it's possible there are bugs. If this works and it's not too complex on your side, I can open a PR to automate the creation of the attention mask when partial input tokens + KV cache are passed to generate.
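
A minimal sketch of workaround 1, assuming a plain decoder-only checkpoint (gpt2 is only a placeholder here): building both processors with the same penalty value makes the boost from EncoderRepetitionPenaltyLogitsProcessor cancel the penalty that RepetitionPenaltyLogitsProcessor applies to the prompt tokens, so in effect only newly generated tokens get penalized.

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    EncoderRepetitionPenaltyLogitsProcessor,
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Transcribe the following audio:", return_tensors="pt")

penalty = 3.0
logits_processor = LogitsProcessorList([
    # Penalizes every token seen so far (prompt + generated).
    RepetitionPenaltyLogitsProcessor(penalty=penalty),
    # Boosts the prompt tokens by the same factor, cancelling the penalty on the prompt.
    EncoderRepetitionPenaltyLogitsProcessor(
        penalty=penalty, encoder_input_ids=inputs["input_ids"]
    ),
])

model_outputs = model.generate(
    **inputs,
    logits_processor=logits_processor,
    max_new_tokens=32,
)

And a rough sketch of workaround 2, reusing the model, tokenizer, and inputs from above: prefill once to build the KV cache, pick the first new token manually, then hand generate the cache plus that token. The attention mask has to be built by hand to cover the cached prompt, and as noted above this path may still have bugs.

import torch

# Prefill: run the prompt once to build the KV cache and greedily pick the first new token.
with torch.no_grad():
    prefill = model(**inputs, use_cache=True)
first_token = prefill.logits[:, -1, :].argmax(dim=-1, keepdim=True)

# The attention mask must cover the cached prompt plus the single token passed in.
attention_mask = torch.ones(
    (1, inputs["input_ids"].shape[-1] + 1), dtype=torch.long
)

# generate only "sees" the newly generated tokens here, so the repetition penalty
# is never applied to the prompt.
model_outputs = model.generate(
    input_ids=first_token,
    attention_mask=attention_mask,
    past_key_values=prefill.past_key_values,
    repetition_penalty=3.0,
    max_new_tokens=32,
)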

@alex-jw-brooks (Contributor, Author) commented Apr 19, 2025

Hi @gante, thanks a lot for the quick response, I really appreciate the suggestions! That all makes total sense and I agree. I'm not a fan of the flag either, but thought it might be nice to avoid changing the default behavior.

Do you have any thoughts on keeping the change to the RepetitionPenaltyLogitsProcessor, but just avoiding exposing it through .generate, like the current state of this PR?

This might be simpler, because then if people do end up hitting this in the future, there is a clear workaround (passing it as a custom logits processor), and it avoids touching the generate API 🙂

from transformers import RepetitionPenaltyLogitsProcessor

...

logits_processor = [
    # Penalize repetition only among tokens generated after the prompt.
    RepetitionPenaltyLogitsProcessor(
        penalty=3.0,
        input_ids_seq_length=model_inputs["input_ids"].shape[-1],
    )
]

model_outputs = model.generate(
    **model_inputs,
    # ... other generation kwargs
    logits_processor=logits_processor,
)

@gante (Contributor) commented Apr 21, 2025

Yeah, I'd be happy to add the flag to the processor, but not to the generation config -- it would indeed be the best of both worlds 🤗

@gante (Contributor) left a comment

Looks mostly good to me, added a few documentation/signature nits 🤗

penalty (`float`):
    The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated
    tokens. Between 0.0 and 1.0 rewards previously generated tokens.
input_ids_seq_length (`int`, *optional*, defaults to 0):

Let's rename it to something like prompt_ignore_length, so that it is clear that it modifies the original behavior.

input_ids_seq_length, at first glance, looks like an input we should provide to get the correct behavior, as opposed to a flag that modifies the behavior.

Let's also document the flag accordingly (it's an optional modifier), and explain that if it is to be used the class must be passed as a custom processor to generate. Ideally, with an example in the doctest below :D
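
For the doctest, something along these lines could work (only a sketch: it assumes the kwarg ends up named prompt_ignore_length as suggested above, and gpt2 is just a placeholder checkpoint):

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("I'm not going to be able to", return_tensors="pt")

# prompt_ignore_length skips the prompt tokens, so only newly generated tokens are
# penalized. Because this modifies the default behavior, the processor is passed
# explicitly as a custom logits processor rather than through a generation-config flag.
rep_pen = RepetitionPenaltyLogitsProcessor(
    penalty=2.0, prompt_ignore_length=inputs["input_ids"].shape[-1]
)

outputs = model.generate(
    **inputs,
    logits_processor=LogitsProcessorList([rep_pen]),
    max_new_tokens=20,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))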

@gante (Contributor) left a comment

LGTM, thank you for iterating 🤗

@alex-jw-brooks (Contributor, Author) commented

Awesome, thanks a lot for the quick reviews! 🙂

@gante gante merged commit a42ba80 into huggingface:main Apr 21, 2025
18 checks passed
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025
…uggingface#37625)

* Allow exclusion of input IDs for repetition penalty

* Add logit proc tests for rep penalty exclusion

* Expose rep pen flag through generate

* Only slice if needed

* keep current rep pen default behavior

* Revert exposing reppen changes through generate

* Fix test arg

* Update src/transformers/generation/logits_process.py

Co-authored-by: Joao Gante <[email protected]>

* Rename to rep penalty kwarg

* Add custom repetition penalty processor example

* Validate prompt_ignore_length

---------

Co-authored-by: Joao Gante <[email protected]>
Successfully merging this pull request may close these issues:

Is it correct that the repetition penalty is applied to the input_ids encompassing all inputs and outputs, rather than solely on the generated tokens? (#36642)