This repository contains the code for the paper Zero-Shot Tokenizer Transfer. ZeTT frees language models from their tokenizer, allowing you to use any model with any tokenizer, with little or no extra training⚡
| Hypernetwork | ..for Model | Comments |
|---|---|---|
| benjamin/zett-hypernetwork-xlm-roberta-base | xlm-roberta-base | multilingual, 26 languages |
| benjamin/zett-hypernetwork-Mistral-7B-v0.1 | mistralai/Mistral-7B-v0.1 | English + Code |
| benjamin/zett-hypernetwork-multilingual-Mistral-7B-v0.1 | mistralai/Mistral-7B-v0.1 | multilingual, 26 languages |
| benjamin/zett-hypernetwork-TinyLlama-1.1B-intermediate-step-1431k-3T | TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T | English + Code |
| benjamin/zett-hypernetwork-Meta-Llama-3-8B-experimental | meta-llama/Meta-Llama-3-8B | experimental English + Code, seems to underperform on Code |
Requirements are in requirements.txt. For example, the following creates a working environment:
```bash
conda create -n zett python=3.11
conda activate zett

pip install -r requirements.txt
pip install -U "jax[cuda12_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html  # adjust based on your CUDA version
pip install -e .
```
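Not part of the original instructions, but a quick way to check that the CUDA-enabled JAX build was picked up:

```python
import jax

print(jax.devices())  # should list your GPU(s), e.g. [cuda(id=0)], rather than only CPU devices
```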
Let's transfer XLM-RoBERTa to the GPT2 tokenizer.
```bash
git clone https://huggingface.co/benjamin/zett-hypernetwork-xlm-roberta-base
```
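Alternatively (this is a suggestion, not part of the original instructions), you can fetch the checkpoint via huggingface_hub and pass the returned local directory as --checkpoint_path, assuming the script only needs the checkpoint files on disk:

```python
from huggingface_hub import snapshot_download

# downloads the hypernetwork files to the local cache and returns the directory path
checkpoint_path = snapshot_download("benjamin/zett-hypernetwork-xlm-roberta-base")
print(checkpoint_path)
```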
```bash
python3 scripts/transfer.py \
--target_model=FacebookAI/xlm-roberta-base \
--tokenizer_name=gpt2 \
--output=my-new-fancy-xlm-r \
--model_class=AutoModelForMaskedLM \
--lang_code=en \
--checkpoint_path=zett-hypernetwork-xlm-roberta-base \
--save_pt # otherwise saves only Flax weights
```

Tada!
```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("my-new-fancy-xlm-r")
model = AutoModelForMaskedLM.from_pretrained("my-new-fancy-xlm-r")
out = model(**tokenizer("Hello world!", return_tensors="pt"))
```
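As an informal sanity check (an addition to the original instructions), you can decode the model's argmax predictions — the output ids now index into the GPT2 vocabulary:

```python
# continuing from the snippet above
predicted_ids = out.logits.argmax(-1)
print(tokenizer.batch_decode(predicted_ids))
```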
..or Mistral-7B to the GPT-NeoX tokenizer:

```bash
git clone https://huggingface.co/benjamin/zett-hypernetwork-Mistral-7B-v0.1

# because Flax weights are not merged in the main branch, we need to specify the revision of a PR containing Flax weights
python3 scripts/transfer.py \
--target_model=mistralai/Mistral-7B-v0.1 \
--revision=refs/pr/95 \
--tokenizer_name=EleutherAI/gpt-neox-20b \
--output=my-new-fancy-mistral \
--model_class=AutoModelForCausalLM \
--checkpoint_path=zett-hypernetwork-Mistral-7B-v0.1 \
--save_pt # otherwise saves only Flax weights
```

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("my-new-fancy-mistral")
model = AutoModelForCausalLM.from_pretrained("my-new-fancy-mistral")
out = model(**tokenizer("Hello world!", return_tensors="pt"))
```
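Continuing from the snippet above, a quick (unofficial) smoke test with greedy decoding — note that the full 7B model is loaded in PyTorch here, so this needs ample memory:

```python
import torch

# reuse `model` and `tokenizer` loaded above
inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```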
Although the codebase is in Jax/Flax, there are PyTorch bindings for the model in ./hf_hypernet. You can use them as follows:

```python
import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

from zett.utils import get_surface_form_matrix

base_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
hypernet = AutoModel.from_pretrained("benjamin/zett-hypernetwork-Mistral-7B-v0.1", trust_remote_code=True)

# concatenate the base model's input and output embeddings along the feature dimension
source_embeddings = torch.concatenate([
    base_model.get_input_embeddings().weight.data,
    base_model.get_output_embeddings().weight.data,
], dim=1)

hn_tokenizer = AutoTokenizer.from_pretrained("benjamin/zett-hypernetwork-Mistral-7B-v0.1")
target_surface_forms = get_surface_form_matrix(
    ["Ġhello", "Ġworld"],  # byte representation of the tokens to predict
    maxlen=hypernet.config.hn_surface_maxlen,
    tokenizer_to_use=hn_tokenizer,
)[0]

# the last output is the predicted bias in case the model uses a bias (e.g. XLM-R)
predicted_input_embeddings, predicted_output_embeddings, _ = hypernet(
    torch.from_numpy(target_surface_forms),
    source_embeddings=source_embeddings,
)
```

..but transfer.py is currently not ported to PyTorch (PRs welcome!).
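For illustration, here is roughly what such a port would do with the predictions: write the predicted rows into the model's (resized) embedding matrices. This is only a sketch, assuming you predicted embeddings for the entire target vocabulary (new_vocab is a hypothetical list of its byte-level surface forms); it is not the actual transfer.py logic.

```python
# Illustrative sketch only — not the actual transfer.py logic.
# Assumes `new_vocab` holds the byte-level surface forms of the full target vocabulary
# and that the hypernetwork was run on all of them (above we only predicted two tokens).
with torch.no_grad():
    base_model.resize_token_embeddings(len(new_vocab))
    base_model.get_input_embeddings().weight.copy_(predicted_input_embeddings)
    base_model.get_output_embeddings().weight.copy_(predicted_output_embeddings)
```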
The script used to train the hypernetwork is train.py.
But first, you'll need to download and prepare the data via data/prepare.py and data/prepare_code.py.
You'll also need to install the Rust module in rust_utils (used to quickly sample tokenizers) via e.g. cd rust_utils && maturin develop --release.
Once finished, you can run training using the configs in configs/. For example, to train a hypernetwork for TinyLlama on English and Code:

```bash
python3 train.py configs/zeroshot/v7:tinyllama_en+code:lw=0.5_long.json
```
Use scripts/apply_to_ft.py to transfer the tokenizer of a fine-tuned model, given a base model whose tokenizer has already been transferred. For example:
```bash
python3 scripts/apply_to_ft.py \
--output=transferred-chat-mistral \
--base_model_path=mistralai/Mistral-7B-v0.1 \
--ft_model_path=mistralai/Mistral-7B-Instruct-v0.1 \
--tokenizer_swapped_base_model_path=path-to-base-model-with-new-tokenizer \
--lambdas 0.5
```
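The idea behind the script, as I read it (the exact handling of the embedding matrices is in scripts/apply_to_ft.py itself), is task-arithmetic style: apply the fine-tuning delta, scaled by --lambdas, on top of the base model with the swapped tokenizer. A rough sketch:

```python
# Rough sketch of the idea only — see scripts/apply_to_ft.py for the actual logic,
# including how embedding matrices with mismatching vocabulary sizes are handled.
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
ft_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
swapped_base = AutoModelForCausalLM.from_pretrained("path-to-base-model-with-new-tokenizer")

lam = 0.5  # corresponds to --lambdas
base_sd, ft_sd, swapped_sd = base_model.state_dict(), ft_model.state_dict(), swapped_base.state_dict()

new_sd = {}
for name, p_swapped in swapped_sd.items():
    if name in ft_sd and ft_sd[name].shape == p_swapped.shape:
        # apply the (scaled) fine-tuning delta on top of the tokenizer-swapped base
        new_sd[name] = p_swapped + lam * (ft_sd[name] - base_sd[name])
    else:
        # parameters whose shapes changed (e.g. the embeddings): keep the swapped weights
        new_sd[name] = p_swapped
```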
There are bash scripts in experiments/ to reproduce the main results from the paper. Evaluation on code is still missing because we are using a fork of bigcode-evaluation-harness to fix some issues we encountered; it will be added soon.
A more detailed guide is coming soon... (but feel free to dig into scripts/ in the meantime)
I prioritized releasing the code quickly instead of making it perfectly clean. There may still be remnants of my personal environment used to train the models and other non-niceties. I am in the process of cleaning this up. If you run into any problems or have any questions, please open an issue.
