Fix a regression in encoder-decoder generation cache initialization#46111
Conversation
Signed-off-by: root <[email protected]>
|
Move the fix to t5gemma modeling part, and fix device mismatch bug for failed case: |
vasqu
left a comment
There was a problem hiding this comment.
Can we also add a fast test for this? Or at least a faster test?
Left some smaller comments
| # We do not pass the config to the cross attn cache to avoid initializing SWA | ||
| # --> we use full attention between our cross attentions | ||
| past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache()) | ||
| elif ( |
There was a problem hiding this comment.
Imo, we should override similar to t5gemma2 instead see
| if encoder_attention_mask is None: | ||
| encoder_attention_mask = torch.ones( | ||
| encoder_hidden_states.shape[:2], device=inputs_embeds.device, dtype=torch.bool | ||
| ) |
There was a problem hiding this comment.
This seems weird to me, any reason this was actually needed? If yes, please add a comment as to why
There was a problem hiding this comment.
Well, it is somewhat like a WA here. As if we pass encoder_attention_mask to create_bidirectional_mask, it will hit early_exit in create_bidirectional_mask, and avoid the device mismatch issue afterwards. I think it should be a common issue for cross-attention mask, will fix the bug in masking_utils.py
Signed-off-by: Liu, Kaixuan <[email protected]>
Signed-off-by: Liu, Kaixuan <[email protected]>
Signed-off-by: Liu, Kaixuan <[email protected]>
|
@vasqu I have resolved the comments you left, can you help review it again? Thx! |
vasqu
left a comment
There was a problem hiding this comment.
LGTM, can we move the mask fix to a different PR. Other than that we can merge then
| # Use `inputs_embeds.device` to stay consistent with `_preprocess_mask_arguments`, which moves the 2D | ||
| # `attention_mask` to that device. In model parallel setups, `encoder_hidden_states` may live on a different | ||
| # device than `inputs_embeds` (e.g. cross-attention from a decoder to encoder states). | ||
| device = inputs_embeds.device |
There was a problem hiding this comment.
Can we move this to a separate PR but thanks for fixing, that is a good point
Signed-off-by: Liu, Kaixuan <[email protected]>
|
@vasqu Done. |
|
[For maintainers] Suggested jobs to run (before merge) run-slow: t5gemma |
|
Thanks! Merging |
…uggingface#46111) * Fix a regression in encoder-decoder generation cache initialization Signed-off-by: Liu, Kaixuan <[email protected]> * move the fix to modeling part Signed-off-by: root <[email protected]> * update code Signed-off-by: Liu, Kaixuan <[email protected]> * Override `_prepare_cache_for_generation` Signed-off-by: Liu, Kaixuan <[email protected]> * add related fast test Signed-off-by: Liu, Kaixuan <[email protected]> * revert masking_utils.py change Signed-off-by: Liu, Kaixuan <[email protected]> --------- Signed-off-by: Liu, Kaixuan <[email protected]> Signed-off-by: root <[email protected]> Co-authored-by: root <[email protected]>
…uggingface#46111) * Fix a regression in encoder-decoder generation cache initialization Signed-off-by: Liu, Kaixuan <[email protected]> * move the fix to modeling part Signed-off-by: root <[email protected]> * update code Signed-off-by: Liu, Kaixuan <[email protected]> * Override `_prepare_cache_for_generation` Signed-off-by: Liu, Kaixuan <[email protected]> * add related fast test Signed-off-by: Liu, Kaixuan <[email protected]> * revert masking_utils.py change Signed-off-by: Liu, Kaixuan <[email protected]> --------- Signed-off-by: Liu, Kaixuan <[email protected]> Signed-off-by: root <[email protected]> Co-authored-by: root <[email protected]>
For encoder-decoder models,
generate()was passing the decoder config to both the self-attention cache and the cross-attention cache. This is incorrect for models like T5Gemma with decoder sliding-window layers: the cross-attention cache could inherit the decoder sliding-window structure and truncate encoder key/value states, causing FlashAttention generation to crash.This PR keeps the decoder config for the self-attention cache, but initializes the cross-attention cache without the decoder config so cross-attention remains full-length.
Tested with:
@vasqu @ArthurZucker @Cyrilvallez, pls help review, thx!