Adding handle_long_generation parameters for text-generation pipeline.
#14118
Conversation
sgugger
left a comment
Thanks for adding this!
- :obj:`None` : default strategy where nothing in particular happens
- :obj:`"hole"`: Truncates left of input, and leaves a gap wide enough to let generation happen (might
  truncate a lot of the prompt and not suitable when generation exceed the model capacity)
Suggested change:
- truncate a lot of the prompt and not suitable when generation exceed the model capacity)
+ truncate a lot of the prompt and not suitable when generation exceeds the model capacity)
Not a huge fan of the word "hole" - can we maybe call it "truncate_left" instead?
@patrickvonplaten Thanks for that remark, finding good names is important.
truncate_left is a bit misleading though, because all methods will end up truncating left (I don't think there's a way to avoid it, really). The 3 options I have in mind:
- hole: Truncate left a chunk equivalent to the desired new tokens, then run generation as usual. You might delete quite a bit of context for your generation.
- correct_and_slow: Don't truncate left at the start, but disable past_key_values and start generating 1 token at a time without using past, truncating left one by one as you generate new tokens.
- incorrect_and_fast: Same as correct_and_slow but keep using past_key_values to keep the same performance. It's correct on models without position embeddings (not that many, I believe), and will drift on other models (if it's only a few tokens of difference it might be negligible).
Do you think there are other ways to handle long generation ?
I agree hole might not be very clear or well suited.
truncate_chunk_left? truncate_block?
correct_and_slow could be... strict maybe (it conveys IMO the correctness of it and the fact that it's limiting in some form), or slow (if the other is fast).
incorrect_and_fast could be fast or approximate.
Other names could be more descriptive:
- rotate_left_no_past
- rotate_left_keep_past
maybe?
BTW, I don't intend to add the fast/slow options soon, because they require more work within generate and don't seem as necessary for the moment; they just seem like other strategies that could be valid depending on the context.
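To make the `hole` option concrete, here is a minimal sketch of the left-truncation it performs (the function and variable names are illustrative, not the PR's actual code):

# Illustrative sketch of the `hole` strategy: drop enough tokens from the
# left of the prompt so that `new_tokens` generated tokens still fit within
# the model's maximum length.
def truncate_for_hole(input_ids, model_max_length, new_tokens):
    keep_length = model_max_length - new_tokens
    if keep_length <= 0:
        raise ValueError("Requested more new tokens than the model can hold")
    # Keep the rightmost (most recent) part of the prompt.
    return input_ids[-keep_length:]


prompt_ids = list(range(800))  # pretend the prompt tokenizes to 800 ids
truncated = truncate_for_hole(prompt_ids, model_max_length=512, new_tokens=100)
assert len(truncated) == 412   # 512 - 100 ids of context survive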
keep_length = self.tokenizer.model_max_length - new_tokens
if keep_length <= 0:
    raise ValueError(
        "We cannot use `hole` to handle this generation the number of desired tokens exceeds the models max length"
Is there a dot missing? I can't make sense of this error.
If you want to generate 100 tokens and the model can only handle 50, then it's impossible to generate 100 with this method.
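Spelled out with those numbers (purely illustrative):

model_max_length = 50    # what the model can handle
new_tokens = 100         # what the user asked for
keep_length = model_max_length - new_tokens  # -50
# keep_length <= 0: no amount of left-truncation leaves room for 100 new
# tokens, so raising the ValueError above is the only sensible outcome.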
if model.config.__class__.__name__ == "RobertaConfig":
    tokenizer.model_max_length = model.config.max_position_embeddings - 2
elif (
Is this what you discussed with @patrickvonplaten or is it separate?
Not sure how it is linked to this PR.
What about Longformer and BART?
I guess Longformer falls into the category of max_model_length > 1000, which does not have the test (since it would require creating a string long enough to go out of the model's bounds, which is a bit too much for tests IMO).
Bart didn't seem to exhibit the same flaw. I checked the source code; it seems Bart hardcoded the +2 within its embeddings directly, allowing for the necessary capacity no matter the configs. https://github.com/huggingface/transformers/blob/master/src/transformers/models/bart/modeling_bart.py#L114
@sgugger yes it was. Actually the issue lay in the upstream model used for testing, which incorrectly defined its tokenizer.model_max_length (since modified): https://huggingface.co/sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english
Here we still need this update because for pipelines tests:
- We take the tokenizer from ModelTester checkpoint
- We train a smaller tokenizer from it
- We create a random model from ModelTester.get_config (or ModelTester.get_pipeline_config)
That means the actual values between the two can be out of sync (which is the case here), so we still need that override.
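Roughly, the override discussed here boils down to something like this (a sketch; the `elif` condition is cut off in the hunk above, so its completion here is only one plausible guess):

# Sketch: re-sync the tokenizer's model_max_length with the randomly
# configured model, since the two come from different sources in the tests.
if model.config.__class__.__name__ == "RobertaConfig":
    # RoBERTa-style models offset position ids by the padding index, so the
    # usable sequence length is max_position_embeddings - 2.
    tokenizer.model_max_length = model.config.max_position_embeddings - 2
elif hasattr(model.config, "max_position_embeddings"):
    tokenizer.model_max_length = model.config.max_position_embeddings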
else:
    new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len
    if new_tokens < 0:
        raise ValueError("We cannot infer how many new tokens are expected")
Maybe change the error message to something like "input_string corresponding to {cur_len} tokens exceeds maximum generation length of model {max_len}}"
But also, why can't we cut here? We can cut the input tokens the same way they're cut when max_new_tokens is defined, no?
If max_length > model_max_length, a lot of warnings/errors are triggered.
So this would actually require users to use a different max_length than what they would send to generate (since generate doesn't allow max_length > model_max_length, in TF at least).
Semantically, I think this becomes very confusing. Even in tests, it would require doing clever calculations, involving somehow knowing the model capacity and the current text length in tokens (which you don't have beforehand). It also implies modifying a given argument (max_length) before sending it downstream to generate, which IMHO is very bad design / very confusing.
Using max_new_tokens here is the only way that's simple and declares the intention unambiguously IMO.
Maybe to clarify:
# model_max_len == 512
output = pipe("Some very long... text", max_length=750)  # long text has 800 ids
Then what are we supposed to do? There's no way to know how many tokens the user wanted originally (they wanted -50, apparently).
I updated the error to clarify.
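For reference, with the parameters this PR discusses, the unambiguous way to express the intent looks something like this (a sketch; `gpt2` is just an example checkpoint):

from transformers import pipeline

pipe = pipeline("text-generation", model="gpt2")  # gpt2 holds at most 1024 tokens

# Asking for 100 *new* tokens plus handle_long_generation="hole" states the
# intent unambiguously: truncate the prompt on the left so that 100 newly
# generated tokens still fit in the model's window.
output = pipe(
    "Some very long prompt ... " * 500,
    max_new_tokens=100,
    handle_long_generation="hole",
)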
LysandreJik
left a comment
This is a nice addition, thank you for working on it @Narsil.
`max_new_tokens` is used to make it possible to understand the user's intent. Otherwise, `max_length` == `tokenizer.model_max_length` < `input_ids.shape[0]`.
Force-pushed from 529f3c5 to 08df6ae.
What does this PR do?
Fixes #14033
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.