
Possible Bug with KV Caching in Llama (original) model #25420

@maximkha

Description

System Info

transformers==4.31.0

  • huggingface_hub version: 0.15.1
  • Platform: Linux-5.15.0-78-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Running in iPython ?: No
  • Running in notebook ?: No
  • Running in Google Colab ?: No
  • Token path ?: /u/k/h/khanov/.cache/huggingface/token
  • Has saved token ?: False
  • Configured git credential helpers:
  • FastAI: N/A
  • Tensorflow: N/A
  • Torch: 2.0.0
  • Jinja2: 3.0.3
  • Graphviz: N/A
  • Pydot: N/A
  • Pillow: 9.0.1
  • hf_transfer: N/A
  • gradio: N/A
  • numpy: 1.24.2
  • ENDPOINT: https://huggingface.co
  • HUGGINGFACE_HUB_CACHE: /u/k/h/khanov/.cache/huggingface/hub
  • HUGGINGFACE_ASSETS_CACHE: /u/k/h/khanov/.cache/huggingface/assets
  • HF_TOKEN_PATH: /u/k/h/khanov/.cache/huggingface/token
  • HF_HUB_OFFLINE: False
  • HF_HUB_DISABLE_TELEMETRY: False
  • HF_HUB_DISABLE_PROGRESS_BARS: None
  • HF_HUB_DISABLE_SYMLINKS_WARNING: False
  • HF_HUB_DISABLE_EXPERIMENTAL_WARNING: False
  • HF_HUB_DISABLE_IMPLICIT_TOKEN: False
  • HF_HUB_ENABLE_HF_TRANSFER: False

Who can help?

@ArthurZucker, @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

While working on a custom decoding method, I found that greedy decoding with KV caching deviates from greedy decoding without it.

import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm

MODEL_PATH = "/nobackup-fast/khanov/llama-7b" # "huggyllama/llama-7b"
GEN_DEV = "cuda:0"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16).to(GEN_DEV)

def get_input_ids(prompt: str) -> torch.Tensor:
    return tokenizer(prompt, return_tensors="pt").input_ids.to(GEN_DEV)

def tokens_to_text(tokens: torch.Tensor):
    return tokenizer.batch_decode(tokens, skip_special_tokens=True)


PROMPT = "This is a " # this is just a test prompt

# greedy decoding without caching 
tokens = get_input_ids(PROMPT)
for _ in tqdm(range(40)):
    with torch.no_grad():
        mout = model(tokens)
    tokens = torch.hstack((tokens, torch.argmax(mout.logits[0, -1]).unsqueeze(0).unsqueeze(0)))
    
without_cache = tokens_to_text(tokens)[0]
print(f"{without_cache=}")

# greedy decoding WITH caching
tokens = get_input_ids(PROMPT)
cached = None
for _ in tqdm(range(40)):
    with torch.no_grad():
        if cached is None:
            # first step: run the full prompt and keep the KV cache
            mout = model(tokens, output_hidden_states=True, use_cache=True)
        else:
            # later steps: the full sequence so far is passed together with the cache
            mout = model(tokens, past_key_values=cached, use_cache=True, output_hidden_states=True)
        cached = mout.past_key_values
    tokens = torch.hstack((tokens, torch.argmax(mout.logits[0, -1]).unsqueeze(0).unsqueeze(0)))
    
with_cache = tokens_to_text(tokens)[0]
print(f"{with_cache=}")

# greedy search with the HF generate() implementation
tokens = get_input_ids(PROMPT)
tokens = model.generate(tokens, num_return_sequences=1, max_new_tokens=40)
generate_output = tokens_to_text(tokens)[0]
print(f"{generate_output=}")

# this matches exactly
assert without_cache == generate_output

# this does not!
assert without_cache == with_cache

Expected behavior

I expected the results not to change when passing the past_key_values kwarg; however, when past_key_values is passed, the model assigns different logits to the same tokens. This also deviates from the model.generate behavior. Possibly related to #18809 and #21080.
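
For comparison, generate feeds only the newly produced token together with the cache at each step, rather than the whole sequence. Below is a minimal sketch of that incremental pattern (untested here; it reuses model, tokenizer, get_input_ids, tokens_to_text, and PROMPT from the reproduction above), which I would expect to match the no-cache output:

# greedy decoding with caching, feeding only the newest token each step
tokens = get_input_ids(PROMPT)
cached = None
next_input = tokens  # first step: the full prompt
for _ in tqdm(range(40)):
    with torch.no_grad():
        mout = model(next_input, past_key_values=cached, use_cache=True)
    cached = mout.past_key_values
    next_token = torch.argmax(mout.logits[0, -1]).reshape(1, 1)
    tokens = torch.hstack((tokens, next_token))
    next_input = next_token  # later steps: only the new token; positions follow from the cache length

incremental = tokens_to_text(tokens)[0]
print(f"{incremental=}")

If this variant agrees with without_cache, the discrepancy above may come from re-running the already-cached positions together with past_key_values rather than from the cache itself.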
