
Conversation

@Cyrilvallez (Member) commented Aug 8, 2025

What does this PR do?

As per the title: this avoids wasting memory for models with a sliding window. Since I don't want to reintroduce static hybrid caches by default (to avoid all the pitfalls of automatic compilation), but also don't want to waste that memory, this is definitely the way to go.

The only change needed is to pass the config to DynamicCache so it can parse sliding_window/layer_types. If the config is not passed, the behavior is exactly the same as before.
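For illustration, here is a minimal usage sketch. The DynamicCache(config=...) call mirrors the benchmark script below; the tokenizer and generate call are just an example of how the cache would typically be used, not part of this PR:

from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
import torch

model_name = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map=0, torch_dtype=torch.bfloat16)

inputs = tokenizer("Hello", return_tensors="pt").to(0)
# Passing the config lets the cache infer the sliding/hybrid layer structure;
# without it, DynamicCache() behaves exactly as before (full layers only).
cache = DynamicCache(config=model.config)
outputs = model.generate(**inputs, past_key_values=cache, max_new_tokens=20)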

See the following figures for an illustration:

  • top: Mistral 7B, all layers are sliding, so the cache stops growing after reaching the window size of 4096
  • bottom: Gemma 2 9B, 1 out of 2 layers is sliding, so the cache grows "sublinearly" after reaching the window size of 4096
[Figures: cache memory usage vs. input size for Mistral 7B (top) and Gemma 2 9B (bottom)]

Bonus:

  • GPT-OSS 20B: 1 out of 2 layers is sliding with window_size=128, so memory requirements are essentially halved across the whole range (except for very small input sizes < 128, of course). The absolute cache size is lower because the model only has 24 layers and head_dim=64. A rough back-of-the-envelope estimate is sketched below the figure.
[Figure: cache memory usage vs. input size for GPT-OSS 20B]
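For intuition, a rough back-of-the-envelope sketch of the expected KV-cache footprint. The per-layer formula and the kv_heads=8 value are assumptions for illustration, not numbers taken from this PR; only num_layers=24, head_dim=64 and window=128 come from the description above:

def kv_bytes_per_layer(seq_len, kv_heads, head_dim, window=None, dtype_bytes=2, batch=1):
    # Keys + values (factor of 2); a sliding layer never caches more than `window` tokens
    cached_tokens = seq_len if window is None else min(seq_len, window)
    return 2 * batch * kv_heads * cached_tokens * head_dim * dtype_bytes

num_layers, kv_heads, head_dim, window = 24, 8, 64, 128  # kv_heads=8 is an assumption
seq_len = 8000
full_only = num_layers * kv_bytes_per_layer(seq_len, kv_heads, head_dim)
hybrid = (num_layers // 2) * kv_bytes_per_layer(seq_len, kv_heads, head_dim) \
    + (num_layers // 2) * kv_bytes_per_layer(seq_len, kv_heads, head_dim, window=window)
print(f"full-only: {full_only / 1024**3:.3f} GiB, hybrid: {hybrid / 1024**3:.3f} GiB")
# For seq_len >> window, the hybrid cache tends towards half the full-only size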

Adding the benchmark script for posterity:

from transformers import AutoModelForCausalLM, DynamicCache
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt


model_name = "mistralai/Mistral-7B-v0.1"
# model_name = "google/gemma-2-9b-it"
# model_name = "openai/gpt-oss-20b"
device = 0

model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16)

input_sizes = torch.linspace(50, 8000, 50).tolist()

old_sizes = []
new_sizes = []
for size in tqdm(input_sizes):
    with torch.no_grad():
        # Random input ids of the requested length
        input = torch.randint(1000, 3000, (1, int(size)), device=device)

        # initializing DynamicCache without config will use only full layers
        old_output = model(input, past_key_values=DynamicCache(), logits_to_keep=1)
        cache = old_output.past_key_values
        # keys + values (hence the factor of 2), in bytes
        tot = sum(layer.keys.numel() * 2 * layer.keys.element_size() for layer in cache.layers)
        old_sizes.append(tot / 1024**3)

        # Initializing it with the config will infer and use the sliding window/hybrid structure
        new_output = model(input, past_key_values=DynamicCache(config=model.config), logits_to_keep=1)
        cache = new_output.past_key_values
        # keys + values (hence the factor of 2), in bytes
        tot = sum(layer.keys.numel() * 2 * layer.keys.element_size() for layer in cache.layers)
        new_sizes.append(tot / 1024**3)

# Plot cache memory usage before vs. after this change
plt.figure()
plt.plot(input_sizes, old_sizes, "r", label="before")
plt.plot(input_sizes, new_sizes, "b", label="now")
plt.xlabel("Input size [tokens]")
plt.ylabel("Cache memory usage [GiB]")
plt.grid()
plt.legend()
plt.show()
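As a quick sanity check (not part of the original script), one can continue from the variables above and inspect the resulting cache to see which layers were capped at the window size. This assumes the usual (batch, heads, tokens, head_dim) key layout and simply prints whatever layer classes and config attributes are present:

# Continuing from the script above: inspect which layers were created as sliding
with torch.no_grad():
    out = model(
        torch.randint(1000, 3000, (1, 256), device=device),
        past_key_values=DynamicCache(config=model.config),
        logits_to_keep=1,
    )
cache = out.past_key_values
print(getattr(model.config, "sliding_window", None), getattr(model.config, "layer_types", None))
print([type(layer).__name__ for layer in cache.layers])
print([layer.keys.shape[-2] for layer in cache.layers])  # number of cached tokens per layer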

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Cyrilvallez Cyrilvallez force-pushed the dynamic-sliding-hybrid branch from 534a6a4 to 41d55aa on August 11, 2025 08:31
@Cyrilvallez Cyrilvallez changed the title from "New DynamicSlidingWindow layer & caches" to "New DynamicSlidingWindow layer & cache" on Aug 11, 2025
@Cyrilvallez Cyrilvallez changed the title from "New DynamicSlidingWindow layer & cache" to "New DynamicSlidingWindowLayer & associated Cache" on Aug 11, 2025
@Cyrilvallez (Member, Author) commented Aug 11, 2025

All good now: slow tests on mistral, gemma2 and t5gemma are all similar to main (only a slight FA2 issue surfaced on a slow test for mistral, but it's unrelated and is solved by #40002).

@ArthurZucker (Collaborator) left a comment

Perfect! This is long overdue, especially for mistral models, but it will also be effective for gpt oss (hybrid sliding, I think; the sliding window is 128, so it would be great to generate the graph for this one as well!)

@github-actions (Contributor) commented

[For maintainers] Suggested jobs to run (before merge)

run-slow: arcee, aria, bitnet, cohere, cohere2, csm, deepseek_v2, deepseek_v3, diffllama, doge, dots1, emu3, ernie4_5, exaone4, fsmt, gemma2

@Cyrilvallez Cyrilvallez merged commit 41d1717 into main Aug 12, 2025
20 of 25 checks passed
@Cyrilvallez Cyrilvallez deleted the dynamic-sliding-hybrid branch August 12, 2025 12:09