
Lookahead Tree-Based Rollouts for RLVR

This is the official implementation of the paper Lookahead Tree-Based Rollouts for Enhanced Trajectory-Level Exploration in Reinforcement Learning with Verifiable Rewards.

(Overview figure)

Installation

pip install -r requirements.txt
# uv is also supported: 
# uv pip install -r requirements.txt

Run RL

Data

Please download our processed dataset from here and place the two directories (countdown-base, math-base) under dataset/.
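
For example, assuming the archives were downloaded and extracted under /path/to/ (a placeholder):

# place the processed datasets where the training scripts expect them
mkdir -p dataset
mv /path/to/countdown-base /path/to/math-base dataset/
ls dataset  # should list: countdown-base  math-base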

Training

We build on the standard VeRL-0.5.0 framework. The training scripts provided in scripts are ready to run directly, with names organized as {algorithm}-{rollout}-{dataset}.sh. For the rollout name, vllm denotes stochastic sampling, and kt denotes lookahead tree-based rollouts (LATR). Examples:

# run standard RL on countdown
bash scripts/dapo-vllm-countdown.sh

# run latr on countdown
bash scripts/dapo-kt-countdown.sh

The results will be logged to the console, EXPERIMENT_NAME.log, and wandb.

The kt scripts differ from the vllm ones only in actor_rollout_ref.rollout.name (which selects a different rollout worker in verl/workers/rollout) and in additional kt arguments (those starting with actor_rollout_ref.rollout.kt), including the following (see the sketch after this list):

  • max_n_branch_per_token: only the top-m tokens are considered when branching (in addition to the thresholding conditions in the paper)
  • prob_filter_abs_thres: $\tau_{abs}$ in the paper
  • prob_filter_rel_thres: $\tau_{rel}$
  • rollout_filter_edit_dist_thres: $\tau_{ed}$
  • rollout_filter_steps: $r$
  • mix_ratio_schedule: a dict mapping training steps to $\eta$ values; typically an exponential decay, but arbitrary schedules are supported
  • return_nb_thres_decay/force_return_step: control the generation step at which LATR exits and falls back to stochastic sampling
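
As referenced above, a minimal sketch of overriding these arguments, in the spirit of the provided scripts (the values are illustrative and not the paper's settings; consult the scripts for the actual entrypoint and defaults):

# sketch: kt overrides in a VeRL-style command (illustrative values)
# model/data arguments from the scripts are omitted here
python3 -m verl.trainer.main_ppo \
    actor_rollout_ref.rollout.name=kt \
    actor_rollout_ref.rollout.kt.max_n_branch_per_token=4 \
    actor_rollout_ref.rollout.kt.prob_filter_abs_thres=0.1 \
    actor_rollout_ref.rollout.kt.prob_filter_rel_thres=0.5 \
    actor_rollout_ref.rollout.kt.rollout_filter_edit_dist_thres=0.2 \
    actor_rollout_ref.rollout.kt.rollout_filter_steps=4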

Our experiments are conducted on 8xH200 GPUs with peak GPU memory usage of around 130GB. On other hardware configurations, if you encounter OOM during training:

  • For OOM in the LATR generation process, reduce the generation batch size actor_rollout_ref.rollout.micro_batch_size. Note that this significantly slows down training.
  • For OOM in compute_logp or update_policy, adjust the training parameters following the official VeRL documentation, especially the batch sizes and gradient checkpointing. This only slightly affects training speed. A sketch of both mitigations follows this list.
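
As a sketch, the two mitigations might look like the following arguments appended to a training script (example values; the gradient-checkpointing key is a standard VeRL option and should be verified against your config):

# illustrative OOM mitigations for a training script's argument list
actor_rollout_ref.rollout.micro_batch_size=8                  # smaller LATR generation batch, slower rollout
actor_rollout_ref.model.enable_gradient_checkpointing=True    # cheaper compute_logp/update_policy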

Evaluation

We also use the VeRL framework for evaluation. Results can be obtained in two ways:

  1. At the last stage of training, evaluation is triggered automatically. The results are printed to the console and logged to both a log file and wandb.
  2. With VeRL checkpoints, you can run the evaluation scripts provided in scripts, with names organized as eval-{dataset}.sh. Just run bash scripts/eval-{dataset}.sh EXPERIMENT_NAME (see the example below); the results are printed to the console and logged to outputs/results/EXPERIMENT_NAME.log.
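
For instance, with a hypothetical experiment name:

# evaluate a finished countdown run from its VeRL checkpoints
bash scripts/eval-countdown.sh my-dapo-kt-run
# results are printed to the console and saved to outputs/results/my-dapo-kt-run.log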

To obtain a CSV from log files, use python tools/read_log.py path/to/log/file_1.log path/to/log/file_2.log .... Both output formats above are supported. The CSV will be saved to ./metrics.csv.
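
For example, to merge a training log and an evaluation log (file names are placeholders):

python tools/read_log.py my-dapo-kt-run.log outputs/results/my-dapo-kt-run.log
# -> writes ./metrics.csv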

Implementation Details

Overview

We propose LATR, a new rollout method that achieves higher trajectory-level diversity. It is largely orthogonal to the policy update, so the overall workflow is the same as in the original VeRL; the only difference is the rollout worker.

We implement the rollout worker in verl/workers/rollout/kt_rollout.py, replacing the original hf_rollout.py. The worker reads the rollout config and passes it to the generate method in model/generation.py, where the core sampling logic is implemented.

Code Structure

data/                # data processing and reward function
model/               # core sampling logic
├── generation.py    # the generate method
├── metrics.py       # metrics for pruning (edit distance, suffix match, rouge-l)
└── sample_filter.py # the filters for both branching and pruning
tools/               # utility functions
verl/                # the verl framework (with our modifications)

Modifications to VeRL

We modify the VeRL framework to support our rollout method. The main changes include (note that line numbers may drift as the code is updated):

  • verl/trainer/ppo/metric_utils.py:process_validation_metrics(): add response length as a validation metric.
  • verl/trainer/ppo/ray_trainer.py:l1195-1199,l1239-1253: implement the diversity-filtering logic for the ablation study.
  • verl/trainer/constants_ppo.py: the runtime environment for PPO training, modified to set the RAY_DEBUG environment variable to 1 to enable debug mode.
  • verl/trainer/main_ppo.py:l205-208: inject the reward function for countdown, since VeRL has no built-in one.
  • verl/workers/fsdp_workers.py:l127: add a timeout to the distributed initialization.
  • verl/workers/fsdp_workers.py:l499-506: the worker initialization, modified to instantiate our own rollout worker when rollout.name is kt.
  • verl/workers/rollout/kt_rollout.py: the rollout worker, a subclass of BaseRollout in verl/workers/rollout/base.py. It reads the rollout config and passes it to the generate method in model/generation.py, where the core sampling logic is implemented.
  • verl/protocol.py:l698-701: modify the chunk method to allow non-uniform batch sizes.
  • verl/utils/reward_score/__init__.py:l43-45: extend the reward function for math-500, amc23, and olympiad.

Citation

If you find this work useful, please consider citing:

@misc{latr,
      title={Lookahead Tree-Based Rollouts for Enhanced Trajectory-Level Exploration in Reinforcement Learning with Verifiable Rewards}, 
      author={Shangyu Xing and Siyuan Wang and Chenyuan Yang and Xinyu Dai and Xiang Ren},
      year={2025},
      eprint={2510.24302},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2510.24302}, 
}

Acknowledgments

We thank the authors of the following projects for their contributions:
