# Critical Tokens Matter: Token-Level Contrastive Estimation Enhances LLM's Reasoning Capability

Repo for "Critical Tokens Matter: Token-Level Contrastive Estimation Enhances LLM's Reasoning Capability".

## News

- All the code from our paper has been released.
- Accepted to ICML 2025 🎉🎉🎉
- Added support for Qwen2.5-7B/32B.
- Released our training data. To train models, please download the data and move it to the `./data` directory.
Figure 1: An illustration of the critical token "owed" shows that it fails to lead to the correct answer in any case.
Replacing it with an alternative can significantly increase model accuracy.
Mathematical reasoning tasks pose significant challenges for large language models (LLMs) because they require precise logical deduction and sequence analysis. In this work, we introduce the concept of critical tokens -- elements within reasoning trajectories that significantly influence incorrect outcomes. We present a novel framework for identifying these tokens through rollout sampling and demonstrate their substantial divergence from traditional error tokens. Through extensive experiments on datasets such as GSM8K and MATH500, we show that identifying and replacing critical tokens significantly improves model accuracy.
Figure 2: Impact of critical tokens on reasoning accuracy. Replacing critical tokens with alternatives ("w/o Critical Tokens")
can significantly increase model accuracy on both GSM8K and MATH500.
We propose an efficient methodology for pinpointing these tokens in large-scale datasets using contrastive estimation and extend this framework to enhance model training processes with direct preference optimization (DPO). Experimental results on GSM8K and MATH500 benchmarks with the widely used models Llama-3 (8B and 70B) and Deepseek-math (7B) demonstrate the effectiveness of the proposed approach, cDPO. Our results underscore the potential of leveraging critical tokens to reduce errors in reasoning tasks, advancing the development of AI systems capable of robust logical deduction.
## Setup

### Cloning the repository

```shell
git clone [email protected]:chenzhiling9954/Critical-Tokens-Matter.git
cd critical_token_release/src
```

### Preparing the conda env

```shell
conda create -n critical_token python=3.10
conda activate critical_token
pip install -r requirements.txt
```

## Training

The training pipeline includes two major tasks: Training for CE and Training for cDPO. Both tasks use LoRA for training.
### Training for CE

To train the CE model (which includes both the positive and negative models), run the following command:

```shell
gpus=<gpus>
export CUDA_VISIBLE_DEVICES=$gpus
python pipeline.py \
    --task_name train_ce \
    --model_name <model_name> \
    --dataset_name <dataset_name> \
    --gpus $gpus
```

### Training for cDPO

To train the cDPO model, run the following command:

```shell
gpus=<gpus>
export CUDA_VISIBLE_DEVICES=$gpus
python pipeline.py \
    --task_name train_cdpo \
    --model_name <model_name> \
    --dataset_name <dataset_name> \
    --gpus $gpus
```

- The `model_name` can be one of the following options: `Meta-Llama-3-8B`, `Meta-Llama-3-70B`, or `deepseek-math-7b-base`.
- The `dataset_name` can be one of the following options: `GSM8K` or `MATH`.
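Conceptually, the CE step scores each token of an incorrect trajectory by contrasting the positive and negative models. The sketch below is only an illustration of that idea, not the repository's implementation: it assumes you already have per-token log-probabilities of the same trajectory under both models, and scores each token by their gap, so tokens the negative model strongly prefers stand out as likely critical tokens.

```python
def token_ce_scores(pos_logprobs, neg_logprobs):
    """Contrastive-estimation sketch (illustrative, not the repo's code):
    score each token by the log-likelihood gap between the positive model
    (fine-tuned toward correct trajectories) and the negative model
    (fine-tuned toward incorrect ones). The lower the score, the more the
    negative model prefers the token relative to the positive model."""
    return [lp_pos - lp_neg for lp_pos, lp_neg in zip(pos_logprobs, neg_logprobs)]

# Toy trajectory of 4 tokens: the third token is strongly preferred by the
# negative model, so it receives the lowest (most suspicious) score.
pos = [-1.0, -0.5, -4.0, -0.8]
neg = [-1.1, -0.6, -0.5, -0.9]
scores = token_ce_scores(pos, neg)
critical_idx = min(range(len(scores)), key=scores.__getitem__)
print(critical_idx)  # → 2
```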
## Evaluation

You can test the LoRA training results with the following command:

```shell
gpus=<gpus>
export CUDA_VISIBLE_DEVICES=$gpus
python pipeline.py \
    --task_name evaluation \
    --model_name <model_name> \
    --dataset_name <dataset_name> \
    --lora_path <lora_path> \
    --gpus $gpus
```

Alternatively, to test an existing model, use the command below:

```shell
gpus=<gpus>
export CUDA_VISIBLE_DEVICES=$gpus
python pipeline.py \
    --task_name evaluation \
    --model_name <model_name> \
    --dataset_name <dataset_name> \
    --lora_path <lora_path> \
    --model_path <model_path> \
    --gpus $gpus
```

- The `model_name` can be one of the following options: `Meta-Llama-3-8B`, `Meta-Llama-3-70B`, or `deepseek-math-7b-base`.
- The `dataset_name` can be one of the following options: `GSM8K` or `MATH`.
## Rollout Sampling

In the Rollout Sampling process, we conducted 64 rollout samplings for each token within an incorrect trajectory. For each token, we calculated a score based on the correctness ratio of the generated completions to quantify its influence on the overall trajectory. The goal is to identify tokens that have a critical impact on the model's output.

```shell
gpus=<gpus>
export CUDA_VISIBLE_DEVICES=$gpus
python pipeline.py \
    --task_name sampling \
    --model_name <model_name> \
    --dataset_name <dataset_name> \
    --cdpo_data_path <cdpo_data_path> \
    --gpus $gpus
```

- The `model_name` can be one of the following options: `Meta-Llama-3-8B`, `Meta-Llama-3-70B`, or `deepseek-math-7b-base`.
- The `dataset_name` can be one of the following options: `GSM8K` or `MATH`.
- The `cdpo_data_path` specifies the source data path for sampling and should follow the format `./data/{model_name}/{dataset_name}.dpo_top_k{k}.T{t}.json`. For example, `./data/Meta-Llama-3-8B/GSM8K.dpo_top_k1.T0.50.json`.
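The per-token score described above can be sketched as follows. This is a minimal illustration, not the repository's implementation: it assumes a hypothetical `rollout(prefix)` callable that samples one completion from the trajectory prefix and returns whether it reaches the correct answer (the paper uses 64 rollouts per token).

```python
def token_scores(tokens, rollout, num_rollouts=64):
    """For each position in an incorrect trajectory, resample completions
    from the prefix ending at that token and score the token by the
    fraction of completions that reach the correct answer. A sharp drop
    in this correctness ratio marks a critical token."""
    scores = []
    for i in range(len(tokens)):
        prefix = tokens[: i + 1]
        correct = sum(rollout(prefix) for _ in range(num_rollouts))
        scores.append(correct / num_rollouts)
    return scores

# Toy check with a deterministic rollout stub that only succeeds while the
# prefix has not yet committed to the bad token "owed".
tokens = ["3", "+", "owed", "="]
stub = lambda prefix: "owed" not in prefix
print(token_scores(tokens, stub, num_rollouts=4))  # → [1.0, 1.0, 0.0, 0.0]
```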
After sampling, we check the results to calculate performance metrics such as "Pass@K", which helps quantify how effective the critical tokens are in influencing the model's predictions.
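For reference, Pass@K is conventionally computed with the unbiased estimator below (the repo's exact metric code is not shown here, so this is the standard formulation): given `n` rollouts of which `c` are correct, it gives the probability that at least one of `k` samples drawn without replacement is correct.

```python
from math import comb

def pass_at_k(n, c, k):
    """Standard unbiased Pass@K estimator:
    1 - C(n-c, k) / C(n, k), i.e. one minus the probability that all k
    drawn samples come from the n-c incorrect rollouts."""
    if n - c < k:
        return 1.0  # too few incorrect rollouts to fill k slots: always passes
    return 1.0 - comb(n - c, k) / comb(n, k)

print(pass_at_k(64, 16, 1))  # → 0.25
```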
## Citation

If you find this repository helpful, please consider citing our paper:

```bibtex
@article{lin2024critical,
  title={Critical Tokens Matter: Token-Level Contrastive Estimation Enhances LLM's Reasoning Capability},
  author={Lin, Zicheng and Liang, Tian and Xu, Jiahao and Wang, Xing and Luo, Ruilin and Shi, Chufan and Li, Siheng and Yang, Yujiu and Tu, Zhaopeng},
  journal={arXiv preprint arXiv:2411.19943},
  year={2024}
}
```