
Accelerating RL for LLM Reasoning with Optimal Advantage Regression


GitHub | arXiv | Hugging Face Collection

Recent advances in LLMs, including OpenAI-o1 and DeepSeek-R1, have demonstrated the remarkable effectiveness of reinforcement learning (RL) with rule-based rewards. However, methods like GRPO and PPO require explicit critics or multiple generations per prompt, resulting in high computational and memory costs.

Can we develop simpler and more efficient RL algorithms for long-context reasoning?


🔥 $A^\star$-PO: Policy Optimization via Optimal Advantage Regression

A new RL algorithm for LLMs that first estimates the optimal value function offline by sampling from the reference policy, then performs on-policy updates with just 1 generation per prompt.

A*-PO Figure 1

⚡ Matches or beats PPO/GRPO while reducing training time by up to 2× and peak memory usage by over 30%.


Installation

conda create -n apo python=3.10
conda activate apo
pip3 install vllm==0.6.3 # or you can install 0.5.4, 0.4.2 and 0.3.1
pip install -e . # verl
pip3 install flash-attn --no-build-isolation # flash attention 2
pip install wandb

Datasets

Preprocessing

# gsm8k
python ./preprocess/data_preprocess/gsm8k.py
# math
python ./preprocess/data_preprocess/math.py
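
To sanity-check what the preprocessing produced, you can peek at the resulting parquet file. The path below assumes the default GSM8K output location, which the generation command in the next section also uses:

```python
import os
import pandas as pd

# Inspect the preprocessed GSM8K training split (default output path assumed).
df = pd.read_parquet(os.path.expanduser("~/data/gsm8k/train.parquet"))
print(df.columns.tolist())  # columns written by the preprocessing script
print(df.iloc[0])           # one full example row
```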

Offline Generation

To estimate the optimal value function, we generate 8 responses per prompt with the reference model and gather their rewards. Our generated data can be found on Hugging Face:

Qwen2.5-1.5B GSM8K MATH
Qwen2.5-3B GSM8K MATH
Qwen2.5-7B GSM8K MATH

If you want to process and generate your own data, you can try the following scripts:

# gsm8k
python ./preprocess/data_generation/model_generate.py --dataset ~/data/gsm8k/train.parquet --remote_dir REMOTE_HUGGINGFACE_DATACARD --reward_function gsm8k
# math
python ./preprocess/data_generation/model_generate.py --dataset ~/data/math/train.parquet --remote_dir REMOTE_HUGGINGFACE_DATACARD --reward_function math
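
The rewards gathered here are what stage 1 turns into a per-prompt estimate of the optimal value $V^\star(x)$. As an illustrative sketch (assuming the standard soft-value, log-sum-exp estimator of the KL-regularized objective; see the paper for the exact form used), the estimate from the $N$ sampled rewards with temperature $\beta_1$ looks like this:

```python
import numpy as np

def estimate_optimal_value(rewards, beta1=0.5):
    """Sketch of a soft value estimate from N reference-policy rewards:
    V_hat(x) = beta1 * log( (1/N) * sum_i exp(r_i / beta1) ).
    The max-shift keeps the exponentials numerically stable."""
    r = np.asarray(rewards, dtype=np.float64)
    m = r.max()
    return m + beta1 * np.log(np.mean(np.exp((r - m) / beta1)))

# Example: 8 rule-based (0/1) rewards for one prompt.
print(estimate_optimal_value([1, 0, 0, 1, 0, 0, 0, 1]))  # ~0.61, between the mean (0.375) and the max (1.0)
```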

Training

# gsm8k
./scripts/apo_gsm8k.sh
# math
./scripts/apo_math.sh

The following are some important hyperparameters used in verl/trainer/config/apo_train.yaml:

| Hyperparameter | Description | Value |
| --- | --- | --- |
| `data.num_gen_to_use` | Number of responses to use for value estimation in stage 1 | 8 |
| `data.beta1` | $\beta_1$ for value estimation in stage 1 | 0.5 |
| `algorithm.beta2` | $\beta_2$ for the least-squares regression in stage 2 | 1e-3 |
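
For intuition on how $\beta_2$ enters stage 2, the sketch below shows one plausible form of the least-squares objective: regress the $\beta_2$-scaled log probability ratio onto the estimated optimal advantage $r(x, y) - \hat{V}^\star(x)$ for the single on-policy generation. This is an illustrative reading, not the verl implementation; see `apo_train.yaml` and the trainer code for the exact loss.

```python
import torch

def apo_stage2_loss(logp_policy, logp_ref, reward, v_star_hat, beta2=1e-3):
    """Illustrative least-squares regression onto the estimated optimal advantage.

    logp_policy, logp_ref: summed token log-probs of the sampled response under
                           the current policy and the (frozen) reference policy
    reward:                rule-based reward of that single on-policy response
    v_star_hat:            stage-1 estimate of the optimal value for the prompt
    """
    advantage_hat = reward - v_star_hat   # estimated optimal advantage r - V_hat*
    log_ratio = logp_policy - logp_ref    # log pi_theta(y|x) - log pi_ref(y|x)
    return ((beta2 * log_ratio - advantage_hat) ** 2).mean()

# Example with batch tensors (shape: [batch]).
logp_policy = torch.tensor([-120.0, -95.0])
logp_ref    = torch.tensor([-118.0, -101.0])
reward      = torch.tensor([1.0, 0.0])
v_star_hat  = torch.tensor([0.61, 0.61])
print(apo_stage2_loss(logp_policy, logp_ref, reward, v_star_hat))
```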

To upload your trained model to Hugging Face, set trainer.default_hub_dir to a Hugging Face repo of your choice.

Our trained models on MATH can be found at:

Qwen2.5-1.5B A*PO PPO GRPO REBEL
Qwen2.5-3B A*PO PPO GRPO REBEL
Qwen2.5-7B A*PO PPO GRPO REBEL

Evaluations on MATH500, Minerva Math, Olympiad Bench, and AMC 23:

A*-PO Evaluations

Acknowledgements

Our pipeline is built on top of TinyZero and verl.

Citing $A^\star$-PO

If you find $A^\star$-PO useful in your research, please consider citing our paper:

@misc{brantley2025acceleratingrlllmreasoning,
      title={Accelerating RL for LLM Reasoning with Optimal Advantage Regression}, 
      author={Kianté Brantley and Mingyu Chen and Zhaolin Gao and Jason D. Lee and Wen Sun and Wenhao Zhan and Xuezhou Zhang},
      year={2025},
      eprint={2505.20686},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2505.20686}, 
}
