Fix a regression in encoder-decoder generation cache initialization #46111

Merged
vasqu merged 10 commits into huggingface:main from kaixuanliu:cross-aatn-cache
May 26, 2026

Conversation

@kaixuanliu
Contributor

For encoder-decoder models, generate() was passing the decoder config to both the self-attention cache and the cross-attention cache. This is incorrect for models like T5Gemma with decoder sliding-window layers: the cross-attention cache could inherit the decoder sliding-window structure and truncate encoder key/value states, causing FlashAttention generation to crash.

This PR keeps the decoder config for the self-attention cache, but initializes the cross-attention cache without the decoder config so cross-attention remains full-length.
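As a rough illustration of the failure mode described above, here is a toy sketch in plain Python. It does not use the transformers cache classes: `ToyCacheLayer` and `make_toy_caches` are hypothetical names invented for this example, and the truncation rule is a simplification of how a sliding-window layer behaves. It only shows why a cross-attention cache that inherits the decoder's sliding window ends up with fewer encoder positions than the encoder produced.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class ToyCacheLayer:
    """Toy stand-in for one cache layer (hypothetical, not the transformers API)."""

    sliding_window: Optional[int] = None  # None means full attention
    states: List[int] = field(default_factory=list)

    def update(self, new_states: List[int]) -> List[int]:
        self.states.extend(new_states)
        if self.sliding_window is not None:
            # Simplified rule: a sliding-window layer keeps only the most recent positions.
            self.states = self.states[-self.sliding_window:]
        return self.states


def make_toy_caches(
    decoder_window: Optional[int], share_config: bool
) -> Tuple[ToyCacheLayer, ToyCacheLayer]:
    """Return (self_attn_layer, cross_attn_layer).

    share_config=True mimics the regression: the cross-attention layer
    inherits the decoder's sliding window.
    """
    self_attn = ToyCacheLayer(sliding_window=decoder_window)
    cross_attn = ToyCacheLayer(sliding_window=decoder_window if share_config else None)
    return self_attn, cross_attn


encoder_positions = list(range(10))  # 10 encoder key/value positions

_, buggy_cross = make_toy_caches(decoder_window=4, share_config=True)
_, fixed_cross = make_toy_caches(decoder_window=4, share_config=False)

buggy_len = len(buggy_cross.update(encoder_positions))
fixed_len = len(fixed_cross.update(encoder_positions))
print(buggy_len, fixed_len)  # 4 10
```

In the toy version, only the cross-attention layer built without the decoder's window keeps all 10 encoder positions, which is the behavior this PR restores for the real cross-attention cache.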

Tested with:

pytest -q -rA tests/models/t5gemma/test_modeling_t5gemma.py::T5GemmaModelTest::test_generate_beyond_sliding_window_with_flash_attn

@vasqu @ArthurZucker @Cyrilvallez, pls help review, thx!

@kaixuanliu kaixuanliu marked this pull request as draft May 25, 2026 08:22
@kaixuanliu kaixuanliu marked this pull request as ready for review May 25, 2026 12:29
@kaixuanliu
Contributor Author

kaixuanliu commented May 25, 2026

Moved the fix into the T5Gemma modeling code, and fixed the device-mismatch bug behind this failing test:
tests/models/t5gemma/test_modeling_t5gemma.py::T5GemmaModelTest::test_model_parallel_beam_search

Contributor

@vasqu vasqu left a comment

Can we also add a fast test for this? Or at least a faster test?

Left some smaller comments

# We do not pass the config to the cross attn cache to avoid initializing SWA
# --> we use full attention between our cross attentions
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache())
elif (
Contributor

Imo, we should override this the way t5gemma2 does instead, see:

def _prepare_cache_for_generation(

Contributor Author

Done

Comment on lines +650 to +653
if encoder_attention_mask is None:
    encoder_attention_mask = torch.ones(
        encoder_hidden_states.shape[:2], device=inputs_embeds.device, dtype=torch.bool
    )
Contributor

This seems weird to me, any reason this was actually needed? If yes, please add a comment as to why

Contributor Author

Well, it is something of a workaround here. If we pass an explicit encoder_attention_mask to create_bidirectional_mask, it hits the early exit in create_bidirectional_mask and so avoids the device mismatch that would otherwise occur afterwards. I think this is a common issue for cross-attention masks, so I will fix the bug in masking_utils.py.

@kaixuanliu
Contributor Author

@vasqu I have resolved the comments you left, can you help review it again? Thx!

Contributor

@vasqu vasqu left a comment

LGTM. Can we move the mask fix to a different PR? Other than that, we can merge.

Comment thread src/transformers/masking_utils.py Outdated
# Use `inputs_embeds.device` to stay consistent with `_preprocess_mask_arguments`, which moves the 2D
# `attention_mask` to that device. In model parallel setups, `encoder_hidden_states` may live on a different
# device than `inputs_embeds` (e.g. cross-attention from a decoder to encoder states).
device = inputs_embeds.device
Contributor

Can we move this to a separate PR? But thanks for fixing it, that is a good point.

Contributor Author

Moved to #46221.

@kaixuanliu
Contributor Author

@vasqu Done.

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: t5gemma

@vasqu vasqu enabled auto-merge May 26, 2026 15:14
@vasqu
Contributor

vasqu commented May 26, 2026

Thanks! Merging

@vasqu vasqu added this pull request to the merge queue May 26, 2026
Merged via the queue into huggingface:main with commit 90e3c4f May 26, 2026
23 checks passed
@kaixuanliu kaixuanliu deleted the cross-aatn-cache branch May 27, 2026 02:07
yuchenxie4645 pushed a commit to yuchenxie4645/transformers that referenced this pull request May 28, 2026
…uggingface#46111)

* Fix a regression in encoder-decoder generation cache initialization

Signed-off-by: Liu, Kaixuan <[email protected]>

* move the fix to modeling part

Signed-off-by: root <[email protected]>

* update code

Signed-off-by: Liu, Kaixuan <[email protected]>

* Override `_prepare_cache_for_generation`

Signed-off-by: Liu, Kaixuan <[email protected]>

* add related fast test

Signed-off-by: Liu, Kaixuan <[email protected]>

* revert masking_utils.py change

Signed-off-by: Liu, Kaixuan <[email protected]>

---------

Signed-off-by: Liu, Kaixuan <[email protected]>
Signed-off-by: root <[email protected]>
Co-authored-by: root <[email protected]>
kashif pushed a commit to kashif/transformers that referenced this pull request Jun 1, 2026
…uggingface#46111)
