A joint project by: Sean McLeish, Ang Li, John Kirchenbauer, Dayal Singh Kalra, Brian R. Bartoldson, Bhavya Kailkhura, Avi Schwarzschild, Jonas Geiping, Tom Goldstein, Micah Goldblum
To cite our work, please use this BibTeX entry:
@article{mcleish2025teaching,
title={Teaching Pretrained Language Models to Think Deeper with Retrofitted Recurrence},
author={Sean McLeish and Ang Li and John Kirchenbauer and Dayal Singh Kalra and Brian R. Bartoldson and Bhavya Kailkhura and Avi Schwarzschild and Jonas Geiping and Tom Goldstein and Micah Goldblum},
journal={arXiv preprint arXiv:2511.07384},
year={2025}
}
We developed in Python 3.11. To install, run:
git clone [email protected]:mcleish7/retrofitting-recurrence.git
cd retrofitting-recurrence
pip install -r requirements.txt
- To download, run:
python utils/download_ds.py --dataset_path YOUR_DATASET_LOCATION
Note: you may need to request permission on HuggingFace to access some of the Nvidia datasets.
- To tokenize, run:
python preprocess_data_packing.py --out_path="llama_1b_packed_nemotron_cc_math_v1_4plus_wrapped_packing" --dataset_location="datasets/Nemotron-CC-Math-v1-4plus" --cache_path=YOUR_CACHE_PATH --save_path=YOUR_SAVE_PATH
You can use the tokenizer_name flag to control the tokenizer being used.
- To save to parquet, run:
python utils/to_parquet.py --dataset_path YOUR_TOKENIZED_DATASET_LOCATION --dataset_save_dir $PROCESSED_DATA_PATH/YOUR_PARQUET_SAVE_LOCATION
As an example, we upload our Llama-3-tokenized FineWeb-Edu-350B parquet dataset here.
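To sanity-check a tokenized parquet dataset (the uploaded one or your own), here is a minimal sketch using the Hugging Face datasets library. The glob path and the input_ids column name are assumptions about your layout, not something guaranteed by our scripts:
# Minimal sketch: inspect a tokenized parquet dataset.
# The path and the "input_ids" column name are placeholders; adjust to your output.
from datasets import load_dataset

ds = load_dataset(
    "parquet",
    data_files="path/to/YOUR_PARQUET_SAVE_LOCATION/*.parquet",
    split="train",
)
print(ds)                        # row count and column names
print(len(ds[0]["input_ids"]))   # packed sequence length of the first row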
To obtain the data mix used in Figure 8, run mix_datasets.py. We split our datasets into shards for processing and tokenization, and we combine some of those shards in mix_datasets.py; if your workflow is different, note that we take approximately 12.8M rows from each split. Due to the Nemotron licence, we cannot openly upload our exact dataset; please open an issue if you run into any trouble here.
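If you need to build a comparable mix yourself, the following is a hedged sketch of taking roughly 12.8M rows from each tokenized split and concatenating them with the datasets library; the split paths are placeholders, and mix_datasets.py remains the authoritative pipeline:
# Illustrative only: build a data mix by taking ~12.8M rows from each split.
# The paths below are placeholders; see mix_datasets.py for our exact pipeline.
from datasets import load_dataset, concatenate_datasets

ROWS_PER_SPLIT = 12_800_000
split_paths = [
    "path/to/split_a/*.parquet",
    "path/to/split_b/*.parquet",
]

parts = []
for path in split_paths:
    ds = load_dataset("parquet", data_files=path, split="train")
    parts.append(ds.select(range(min(ROWS_PER_SPLIT, len(ds)))))

mixed = concatenate_datasets(parts).shuffle(seed=0)
mixed.to_parquet("path/to/mixed_split.parquet")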
We provide conversion scripts for TinyLlama, Llama, and OLMo in convert_pretrained_model, and we provide untrained models (outputs of the conversion scripts) in our collection.
- For TinyLlama/Llama use convert_pretrained_model/convert_llama.py
- For OLMo use convert_pretrained_model/convert_olmo.py
There are multiple steps; we use multiple files to ensure that the converted model is as faithful to the original as possible:
- Download the model using utils/download_to_local.py. Also, download tomg-group-umd/huginn-0125 at revision="972cea674c2f4ea37da6777ece1a0c9895c9998b" into convert_pretrained_model/models/huginn-0125.
- Add the looped_{model}.py file into the downloaded snapshot dir.
- Run the convert_{model}.py code (read the comment at the top of the main function for how to select different model shapes); this will error, but the dir with the new weights will be created.
- In the newly created dir, overwrite the contents of the raven_modeling_minimal.py file with the contents of the raven_modeling_minimal_compare_{model}.py file.
- Rerun convert_{model}.py; this time you should see a lot of True, 0.000 printed, meaning that the hidden states all match exactly. If not, something is wrong: reread all the variables changed in convert_{model}.py, fix, and retry. (A rough illustration of this kind of check is sketched after this list.)
- Overwrite the contents of the raven_modeling_minimal.py file with the contents of the raven_modeling_minimal_{model}.py file. This differs slightly from the compare version in that it returns less information and uses the linear adapter.
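For intuition only, here is a hedged sketch of the kind of hidden-state parity check behind the True, 0.000 output. The real comparison is implemented inside convert_{model}.py together with the compare modelling file; this standalone version assumes both checkpoints expose hidden states through the standard transformers interface, and the paths and tolerance are placeholders:
# Illustrative hidden-state parity check between an original and a converted checkpoint.
# The real check lives in convert_{model}.py; paths and tolerance here are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

orig = AutoModelForCausalLM.from_pretrained("path/to/original")
conv = AutoModelForCausalLM.from_pretrained("path/to/converted", trust_remote_code=True)
tok = AutoTokenizer.from_pretrained("path/to/original")

inputs = tok("The quick brown fox jumps over the lazy dog.", return_tensors="pt")
with torch.no_grad():
    h_orig = orig(**inputs, output_hidden_states=True).hidden_states
    h_conv = conv(**inputs, output_hidden_states=True).hidden_states

for a, b in zip(h_orig, h_conv):
    diff = (a - b).abs().max().item()
    print(torch.allclose(a, b, atol=1e-5), f"{diff:.3f}")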
NOTE: the model conversion code is built to work with transformers==4.51.0 due to a KV-cache breaking change in later versions.
WARNING: We only tested the parts of the modelling files used in this repo (e.g. forward(), generate()); however, we leave in all functions from the Huginn-0125 model. Please use untested features with caution.
Example commands are in the shells/ directory, organised by model. We use the same $PROCESSED_DATA_PATH temporary variable as in the datasets section above; make sure to set it to your specific path.
We use train.py to train; it is based on the Huginn finetuning script but adds extra features, such as parquet data loading and additional optimizers.
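For a rough picture of what parquet data loading looks like (train.py contains the actual pipeline), here is a minimal sketch that feeds packed token sequences from parquet shards into a PyTorch DataLoader; the path and the input_ids column name are assumptions:
# Illustrative only: load packed token sequences from parquet shards for training.
# train.py contains the real data pipeline; the path and column name are placeholders.
from datasets import load_dataset
from torch.utils.data import DataLoader

ds = load_dataset("parquet", data_files="path/to/parquet_shards/*.parquet", split="train")
ds.set_format("torch", columns=["input_ids"])

loader = DataLoader(ds, batch_size=8, shuffle=True)
for batch in loader:
    input_ids = batch["input_ids"]   # (batch, seq_len) packed token ids
    labels = input_ids.clone()       # next-token prediction targets
    break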
Note that the save_n_mins_before_timeout flag is designed to work on Flux scheduling systems only.
Example commands using lm_eval are in the shells/eval.sh file.
We added " Let's think step by step." to the gsm8k-cot prompt; our YAML is in eval_yamls/gsm8k-cot-sean.yaml. Place this alongside gsm8k.yaml in lm_eval.
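If you prefer calling the harness from Python instead of the shell script, here is a hedged sketch using lm_eval's simple_evaluate API; the model path, batch size, and task name are placeholders (the custom task name depends on where you place our YAML):
# Illustrative only: run an lm_eval task from Python rather than the CLI.
# The model path and task name are placeholders.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=path/to/converted_model,trust_remote_code=True",
    tasks=["gsm8k_cot"],
    batch_size=8,
)
print(results["results"])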
For offline validation loss calculations (as there is no validation loop during training), use multi_recurence_eval.py. An example command is at the bottom of the Python file.
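For intuition about what an offline validation loss amounts to, here is a hedged sketch that averages next-token cross-entropy over a small tokenized sample; multi_recurence_eval.py is the authoritative script, and this sketch assumes a standard Hugging Face causal-LM forward that accepts labels, with placeholder paths and column name:
# Illustrative only: average next-token loss over a held-out tokenized sample.
# multi_recurence_eval.py is the real script; paths and column name are placeholders.
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/model", trust_remote_code=True).eval()
ds = load_dataset("parquet", data_files="path/to/val_shards/*.parquet", split="train")
ds.set_format("torch", columns=["input_ids"])

losses = []
with torch.no_grad():
    for row in ds.select(range(64)):          # small sample for illustration
        ids = row["input_ids"].unsqueeze(0)
        out = model(input_ids=ids, labels=ids)
        losses.append(out.loss.item())
print("mean val loss:", sum(losses) / len(losses))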
NOTE: If you get an error like "TypeError: ... got multiple values for keyword argument 'tie_word_embeddings'", remove the tie_word_embeddings key from the model's config.json, as Huginn-0125 uses the tie_embeddings flag instead.
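One minimal way to apply that fix programmatically (the config path is a placeholder):
# Remove the conflicting key from the model's config.json; path is a placeholder.
import json

path = "path/to/model/config.json"
with open(path) as f:
    cfg = json.load(f)
cfg.pop("tie_word_embeddings", None)
with open(path, "w") as f:
    json.dump(cfg, f, indent=2)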
We provide plotting code in plot_evals.py, which is useful for quickly plotting multiple experiments at once. We have left an example of how we plot our OLMo runs here.
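As a rough idea of the kind of quick multi-run comparison plot_evals.py is for, here is a hedged matplotlib sketch; the run names and numbers below are invented placeholders, not our results, and the real script reads our evaluation output files:
# Illustrative only: quick comparison plot across several runs.
# The dictionary below is placeholder data, not our results or file format.
import matplotlib.pyplot as plt

runs = {
    "run_a": {"steps": [0, 1000, 2000], "val_loss": [3.1, 2.7, 2.5]},
    "run_b": {"steps": [0, 1000, 2000], "val_loss": [3.0, 2.6, 2.4]},
}
for name, r in runs.items():
    plt.plot(r["steps"], r["val_loss"], marker="o", label=name)
plt.xlabel("training steps")
plt.ylabel("validation loss")
plt.legend()
plt.savefig("quick_comparison.png", dpi=200)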
We provide the exact plotting code and data used in our paper in paper_plots/. Run paper_plots/plot.py to recreate the plots.
- To untie the embeddings and lm_head for non-recurrent Llama models before training, use utils/untie_embeds_hf.py (see the sketch after this list).
- For ShortGPT experiments we use https://github.com/sramshetty/ShortGPT.
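For reference, untying typically amounts to giving lm_head its own copy of the embedding weights and clearing the tie flag. Below is a hedged sketch of that idea; utils/untie_embeds_hf.py is the script we actually provide, and the paths here are placeholders:
# Illustrative only: untie embeddings and lm_head for a Llama-style HF checkpoint.
# utils/untie_embeds_hf.py is the provided script; paths are placeholders.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/llama")
with torch.no_grad():
    model.lm_head.weight = torch.nn.Parameter(model.get_input_embeddings().weight.clone())
model.config.tie_word_embeddings = False
model.save_pretrained("path/to/llama-untied")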
Please feel free to contact us with any questions, or open an issue on GitHub.