Require input_ids for repetition penalty#45389

Merged
Cyrilvallez merged 5 commits into
huggingface:mainfrom
ruben-aghayan:fix-repetition-penalty-inputs-embeds
May 13, 2026
Conversation

@ruben-aghayan
Contributor

@ruben-aghayan ruben-aghayan commented Apr 13, 2026

What does this PR do?

This PR warns when the repetition penalty or the n-gram repetition penalty is used in decoder models with inputs_embeds but without input_ids.

Previously, users could set a repetition penalty on generate calls that pass inputs_embeds. Because such calls carry no prompt tokens, the penalty was applied only to the generated tokens, not to the prompt.
An equivalent call (i.e. one that passes the tokens corresponding to those embeddings) would behave differently, because it applies the repetition penalty to the prompt tokens as well.
This change makes it so that the repetition penalty is not applied and a warning is shown.
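
To make the mismatch described above concrete, here is a minimal pure-Python sketch of the usual repetition-penalty rule (positive logits are divided by the penalty, negative logits are multiplied by it). The helper name apply_repetition_penalty and the toy values are assumptions made for illustration; this is not the library's implementation.

```python
def apply_repetition_penalty(scores, seen_ids, penalty):
    """Return a copy of `scores` with every token id in `seen_ids` penalized."""
    out = list(scores)
    for tok in set(seen_ids):
        # Dividing a positive logit or multiplying a negative one both lower it.
        out[tok] = out[tok] / penalty if out[tok] > 0 else out[tok] * penalty
    return out


logits = [2.0, 1.0, -1.0, 0.5]  # toy vocabulary of 4 tokens
prompt_ids = [0, 2]             # ids the prompt would map to (hypothetical)
generated_ids = [3]             # ids generated so far (hypothetical)

# With input_ids, the processor sees prompt + generated tokens.
with_prompt = apply_repetition_penalty(logits, prompt_ids + generated_ids, penalty=2.0)

# With inputs_embeds only, there are no prompt ids, so only generated tokens are seen.
embeds_only = apply_repetition_penalty(logits, generated_ids, penalty=2.0)

print(with_prompt)
print(embeds_only)
```

In this sketch tokens 0 and 2 are penalized only in the first call, which is the difference in behavior the description refers to.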

Testing

pytest tests/generation -vv -k 'not test_text_streamer_decode_kwargs'
test_text_streamer_decode_kwargs is excluded because it was failing for an unrelated reason.

Code Agent Policy

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@gante

@afurm

afurm commented Apr 13, 2026

Does prompt_input_ids.get() (or a follow-up check) need to handle the case where it's a list rather than a tensor? If input_ids is passed as a plain Python list, isinstance(..., torch.Tensor) would be False and this would raise the error even for valid input.

@ruben-aghayan ruben-aghayan force-pushed the fix-repetition-penalty-inputs-embeds branch from d1d35b7 to 3a4294c Compare April 13, 2026 05:46
@ruben-aghayan ruben-aghayan changed the title Guard repetition penalty for inputs_embeds Require input_ids for repetition penalty Apr 13, 2026
@ruben-aghayan
Contributor Author

> Does prompt_input_ids.get() (or a follow-up check) need to handle the case where it's a list rather than a tensor? If input_ids is passed as a plain Python list, isinstance(..., torch.Tensor) would be False and this would raise the error even for valid input.

Thank you for your comment!

Is a list considered valid input? Generate args are all tensors. Admittedly, input_ids goes in kwargs, so it is not explicitly typed. But even today, such an input would fail, e.g.

  from transformers import AutoTokenizer, AutoModelForCausalLM

  tok = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
  model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

  ids = tok("Hello world", return_tensors="pt").input_ids[0].tolist()
  model.generate(input_ids=ids)

This produces AttributeError: 'list' object has no attribute 'shape', since the list is being treated as a tensor.

@Rocketknight1
Member

@remi-or @McPatate for generation/CB, but feel free to pass it on to someone else if you're not comfortable reviewing it!

@remi-or
Collaborator

remi-or commented Apr 14, 2026

I am unfamiliar with encoders and inputs_embeds, so I would prefer that it be passed on. If no one picks this up, I will when I have some room!

@Rocketknight1
Member

cc @Cyrilvallez for generation maybe, but if you're overloaded we might need to find someone to own generation code!

@Cyrilvallez
Member

Hey @ruben-aghayan! We can indeed raise in such cases, but this code should live inside _get_logits_processor!

@ruben-aghayan ruben-aghayan marked this pull request as draft April 25, 2026 01:44
@ruben-aghayan ruben-aghayan force-pushed the fix-repetition-penalty-inputs-embeds branch 4 times, most recently from 1721159 to 08ac3d8 Compare April 25, 2026 04:11
@ruben-aghayan
Contributor Author

> Hey @ruben-aghayan! We can indeed raise in such cases, but this code should live inside _get_logits_processor!

thanks + done

I noticed that EncoderRepetitionPenaltyLogitsProcessor above just warns, so I switched to a warning (this would have been enough for my use case). I also extended it to NoRepeatNGramLogitsProcessor.
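
The n-gram case has the same blind spot as the repetition penalty. Below is a simplified pure-Python sketch of n-gram banning, assuming the processor only sees whatever token ids it is given. The function banned_next_tokens is a hypothetical illustration, not the NoRepeatNGramLogitsProcessor code.

```python
def banned_next_tokens(seen_ids, ngram_size):
    """Tokens that would complete an n-gram already present in `seen_ids`."""
    if len(seen_ids) + 1 < ngram_size:
        return set()
    # The last (n - 1) tokens form the prefix of the n-gram about to be completed.
    prefix = tuple(seen_ids[len(seen_ids) - (ngram_size - 1):])
    banned = set()
    for i in range(len(seen_ids) - ngram_size + 1):
        ngram = tuple(seen_ids[i:i + ngram_size])
        if ngram[:-1] == prefix:
            banned.add(ngram[-1])
    return banned


prompt_ids = [5, 6]    # hypothetical prompt tokens
generated_ids = [5]    # hypothetical tokens generated so far

# With input_ids, the bigram (5, 6) from the prompt is known, so 6 is banned next.
with_prompt = banned_next_tokens(prompt_ids + generated_ids, ngram_size=2)

# With inputs_embeds only, the prompt n-grams are invisible and nothing is banned.
embeds_only = banned_next_tokens(generated_ids, ngram_size=2)

print(with_prompt, embeds_only)
```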

@ruben-aghayan ruben-aghayan marked this pull request as ready for review April 25, 2026 04:15
@ruben-aghayan
Contributor Author

Comment thread src/transformers/generation/utils.py Outdated
Comment on lines +1092 to +1100
  inputs_embeds = model_kwargs.get("inputs_embeds") if model_kwargs is not None else None
  if inputs_embeds is not None and (input_ids_seq_length is None or input_ids_seq_length == 0):
      warnings.warn(
          "Passing `repetition_penalty` requires some form of `input_ids` to be passed to "
          "`generate`, ignoring the argument.",
          UserWarning,
      )
  else:
      processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
Member

We don't want to skip, only warn that the repetition penalty will apply only to new tokens, rather than to the full sequence including the prompt.
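
A minimal sketch of the warn-only behavior requested here, assuming a simplified stand-in for the relevant branch. The function build_processors, its parameters, the tuple used in place of a processor object, and the warning text are all hypothetical. The merged code in _get_logits_processor may be structured differently.

```python
import warnings


def build_processors(repetition_penalty, has_inputs_embeds, prompt_len):
    """Hypothetical stand-in for the repetition-penalty branch of processor setup."""
    processors = []
    if repetition_penalty is not None and repetition_penalty != 1.0:
        if has_inputs_embeds and prompt_len == 0:
            # Warn only: the penalty cannot see prompt tokens that were never passed as ids.
            warnings.warn(
                "`repetition_penalty` will only be applied to newly generated tokens "
                "because no `input_ids` were passed alongside `inputs_embeds`.",
                UserWarning,
            )
        # The processor is registered in both cases; nothing is skipped.
        processors.append(("repetition_penalty", repetition_penalty))
    return processors


print(build_processors(1.2, has_inputs_embeds=False, prompt_len=3))
```

The difference from the snippet under review is that there is no else branch: the warning and the registration are independent.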

Contributor Author

done

Contributor Author

hi @Cyrilvallez wdyt?

@ruben-aghayan ruben-aghayan force-pushed the fix-repetition-penalty-inputs-embeds branch from 0361926 to 299a2d8 Compare May 3, 2026 03:55
@ruben-aghayan ruben-aghayan marked this pull request as ready for review May 3, 2026 03:55
@Cyrilvallez
Member

Alright, pushed some changes to simplify quite a bit - faster than reviewing. Thanks for the PR @ruben-aghayan!

@Cyrilvallez Cyrilvallez merged commit eee7039 into huggingface:main May 13, 2026
10 of 24 checks passed
@ruben-aghayan
Contributor Author

> Alright, pushed some changes to simplify quite a bit - faster than reviewing. Thanks for the PR @ruben-aghayan!

Thank you for the review and the changes!

jp1924 pushed a commit to jp1924/transformers that referenced this pull request May 18, 2026
* Guard repetition penalty for inputs_embeds

* Move repetition penalty guard to logits processor

* only warn that repetition processor wont apply to prompt when input embeds are provided but not input ids

* fix all the uneccesary if/else

* fix test

---------

Co-authored-by: Cyril Vallez <[email protected]>
5 participants