System Info
transformers==4.31.0
- huggingface_hub version: 0.15.1
- Platform: Linux-5.15.0-78-generic-x86_64-with-glibc2.35
- Python version: 3.10.12
- Running in iPython ?: No
- Running in notebook ?: No
- Running in Google Colab ?: No
- Token path ?: /u/k/h/khanov/.cache/huggingface/token
- Has saved token ?: False
- Configured git credential helpers:
- FastAI: N/A
- Tensorflow: N/A
- Torch: 2.0.0
- Jinja2: 3.0.3
- Graphviz: N/A
- Pydot: N/A
- Pillow: 9.0.1
- hf_transfer: N/A
- gradio: N/A
- numpy: 1.24.2
- ENDPOINT: https://huggingface.co
- HUGGINGFACE_HUB_CACHE: /u/k/h/khanov/.cache/huggingface/hub
- HUGGINGFACE_ASSETS_CACHE: /u/k/h/khanov/.cache/huggingface/assets
- HF_TOKEN_PATH: /u/k/h/khanov/.cache/huggingface/token
- HF_HUB_OFFLINE: False
- HF_HUB_DISABLE_TELEMETRY: False
- HF_HUB_DISABLE_PROGRESS_BARS: None
- HF_HUB_DISABLE_SYMLINKS_WARNING: False
- HF_HUB_DISABLE_EXPERIMENTAL_WARNING: False
- HF_HUB_DISABLE_IMPLICIT_TOKEN: False
- HF_HUB_ENABLE_HF_TRANSFER: False
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
While working on a custom decoding method, I found a deviation from greedy search when using KV caching.
```python
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm

MODEL_PATH = "/nobackup-fast/khanov/llama-7b"  # "huggyllama/llama-7b"
GEN_DEV = "cuda:0"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16).to(GEN_DEV)


def get_input_ids(prompt: str) -> torch.Tensor:
    global model, tokenizer
    tokens = tokenizer(prompt, return_tensors="pt").input_ids.to(GEN_DEV)
    return tokens


def tokens_to_text(tokens: torch.Tensor):
    return tokenizer.batch_decode(tokens, skip_special_tokens=True)


PROMPT = "This is a "  # this is just a test prompt

# greedy decoding without caching
tokens = get_input_ids(PROMPT)
for _ in tqdm(range(40)):
    with torch.no_grad():
        mout = model(tokens)
    tokens = torch.hstack((tokens, torch.argmax(mout.logits[0, -1]).unsqueeze(0).unsqueeze(0)))
without_cache = tokens_to_text(tokens)[0]
print(f"{without_cache=}")

# greedy decoding WITH caching
tokens = get_input_ids(PROMPT)
cached = None
for _ in tqdm(range(40)):
    with torch.no_grad():
        if cached is None:
            mout = model(tokens, output_hidden_states=True, use_cache=True)
            cached = mout.past_key_values
        else:
            mout = model(tokens, past_key_values=cached, use_cache=True, output_hidden_states=True)
            cached = mout.past_key_values
    tokens = torch.hstack((tokens, torch.argmax(mout.logits[0, -1]).unsqueeze(0).unsqueeze(0)))
with_cache = tokens_to_text(tokens)[0]
print(f"{with_cache=}")

# normal greedy search with the HF generate implementation
tokens = get_input_ids(PROMPT)
tokens = model.generate(tokens, num_return_sequences=1, max_new_tokens=40)
generate_output = tokens_to_text(tokens)[0]
print(f"{generate_output=}")

# this matches exactly
assert without_cache == generate_output
# this does not!
assert without_cache == with_cache
```

Expected behavior
I was expecting the results not to change when using the past_key_values kwarg; however, when passing past_key_values, the model assigns different logits to the tokens. This also deviates from the model.generate behavior. This is possibly related to #18809 and #21080.
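For reference, this is the incremental pattern I would expect to be equivalent: process the prompt once, then feed only the newly generated token together with past_key_values. A minimal sketch, assuming the same MODEL_PATH, GEN_DEV, model, tokenizer, PROMPT, and helper functions defined in the reproduction above:

```python
# Sketch only: reuses model, tokenizer, PROMPT, get_input_ids and
# tokens_to_text from the reproduction script above.
tokens = get_input_ids(PROMPT)
cached = None
next_input = tokens  # first step: feed the whole prompt
for _ in range(40):
    with torch.no_grad():
        mout = model(next_input, past_key_values=cached, use_cache=True)
    cached = mout.past_key_values
    next_token = torch.argmax(mout.logits[0, -1]).unsqueeze(0).unsqueeze(0)
    tokens = torch.hstack((tokens, next_token))
    next_input = next_token  # subsequent steps: feed only the new token
incremental = tokens_to_text(tokens)[0]
print(f"{incremental=}")
```

As far as I can tell, this is closer to how generate consumes the cache (only the new token is passed once past_key_values is populated), so I would expect this variant, the cached loop above, and the uncached loop to all produce the same greedy output.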