microsoft/SuperRL

SuperRL: Reinforcement Learning with Supervision to Boost Language Model Reasoning

Introduction

SuperRL is a unified training framework that adaptively alternates between reinforcement learning and supervised fine-tuning for training large language models, built on top of verl.

⚠️ Important Notice: This repository only provides core SuperRL component implementations, not a complete standalone training framework. These components must be integrated with the original verl codebase and placed in the corresponding verl project locations to function properly.

📄 Paper: This repository contains the open-source implementation of SuperRL: Reinforcement Learning with Supervision to Boost Language Model Reasoning (arXiv:2506.01096)
Authors: Yihao Liu, Shuocheng Li, Lang Cao, Yuhang Xie, Mengyu Zhou, Haoyu Dong, Xiaojun Ma, Shi Han, Dongmei Zhang

🎯 Overview

SuperRL addresses the challenge of sparse rewards in reinforcement learning by adaptively switching between RL and SFT modes. When every rollout for a given instance receives zero reward (indicating the absence of a learning signal), SuperRL falls back to SFT on curated offline data.
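
The switching rule described above can be sketched in a few lines. This is a minimal illustration, not the code in `SuperRLActor.py`: the helper name `select_training_mode` and the grouping of rewards by prompt ID are assumptions made for the example, and `reward_eps` only mirrors the configuration key of the same name that appears later in this README.

```python
from typing import Dict, List


def select_training_mode(
    rewards_by_prompt: Dict[str, List[float]],
    reward_eps: float = 1e-8,
) -> Dict[str, str]:
    """Return "rl" or "sft" for each prompt (hypothetical helper).

    A prompt falls back to SFT only when every rollout reward is
    numerically zero, i.e. RL has no learning signal for it.
    """
    modes = {}
    for prompt_id, rewards in rewards_by_prompt.items():
        has_signal = any(abs(r) > reward_eps for r in rewards)
        modes[prompt_id] = "rl" if has_signal else "sft"
    return modes


if __name__ == "__main__":
    demo = {
        "q1": [0.0, 1.0, 0.0, 0.0],  # one rollout was rewarded: keep RL
        "q2": [0.0, 0.0, 0.0, 0.0],  # no rollout was rewarded: use SFT
    }
    print(select_training_mode(demo))
```

The decision is made per instance, so a single batch can contain both RL and SFT instances.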

Figure: SuperRL framework (see figure/SuperRLFramework.png)

Key Benefits

  • Higher Sample Efficiency: Better utilization of both online exploration and offline demonstrations
  • Stronger Generalization: Improved performance across diverse reasoning benchmarks
  • Improved Robustness: Enhanced stability under sparse reward conditions

🚀 Features

  • Enhanced Actor Implementations: SuperRL, HybridAdvGated, and HybridLogSigma actors
  • Hybrid Dataset Support: Custom dataset class with special handling for tagged_answer and superrl_research_response_ids
  • Dataset Preprocessing: Standardized preprocessing for mathematical reasoning datasets
  • Universal Reward Functions: Comprehensive reward system for reasoning tasks
  • Built on verl v0.5.0: Leverages the latest verl features

📁 Repository Structure

SuperRL-Opensource/
├── actor/                      # Enhanced actor implementations
│   ├── SuperRLActor.py        # Main SuperRL actor with OR logic
│   ├── HybridAdvGatedActor.py  # Hybrid actor with advantage gating
│   └── HybridLogSigmaActor.py  # Hybrid actor with log-sigma approach
├── dataset/                    # Dataset components for verl integration
│   ├── hybrid_dataset.py       # HybridDataset class with SuperRL features
│   ├── response_dataproto_compose.py  # Response data composition utilities
│   └── __init__.py             # Dataset module initialization
├── data_preprocess/            # Dataset preprocessing utilities
│   ├── common_utils.py         # Shared utilities for dataset processing
│   ├── aime2024_preprocess.py  # AIME 2024 dataset preprocessing
│   ├── aime2025_preprocess.py  # AIME 2025 dataset preprocessing
│   ├── gsm8k_preprocess.py     # GSM8K dataset preprocessing
│   ├── hitab_preprocess.py     # HiTab dataset preprocessing
│   ├── limo_preprocess.py      # LIMO dataset preprocessing
│   ├── metamath_preprocess.py  # MetaMath dataset preprocessing
│   ├── openr1_preprocess.py    # OpenR1 dataset preprocessing
│   ├── prm12k_preprocess.py    # PRM12K dataset preprocessing
│   ├── batch_process.py        # Batch processing utilities
│   └── preview_parquet.py      # Dataset preview utilities
├── demo_script/                # Demo and example scripts
│   └── superrl_demo.sh         # SuperRL training demo script
├── figure/                     # Documentation figures and diagrams
│   └── SuperRLFramework.png    # SuperRL framework illustration
├── reward/                     # Reward function implementations
│   └── superrl.py              # Universal reward function for reasoning tasks
├── CODE_OF_CONDUCT.md          # Code of conduct guidelines
├── CONTRIBUTING.md             # Contribution guidelines
├── LICENSE.txt                 # License information
├── README.md                   # Project documentation
├── SECURITY.md                 # Security policy
└── SUPPORT.md                  # Support information

🛠️ Quick Start

Prerequisites

This project is an extension component for verl and cannot run independently. Please complete the verl installation and setup first.

The components were written against verl v0.5.0. Clone that release first, then clone this repository inside the verl project root:

# 1. First install and set up the complete verl environment
git clone --branch v0.5.0 https://github.com/volcengine/verl.git
cd verl

# 2. Clone this SuperRL components repository
git clone https://github.com/your-username/SuperRL-Opensource.git

⚠️ Version Compatibility Note: Different versions of verl may require modifications to the actor implementations. If you encounter compatibility issues when using newer or older versions of verl, you may need to adapt the actor code to match the specific verl version's API and interface requirements.

🔧 Integration Guide

📋 Note: This guide provides step-by-step instructions for integrating SuperRL components into the verl project. All operations should be performed in the verl project root directory.

1. Reward Function Integration

Step 1: Copy SuperRL reward function to verl directory

# Copy SuperRL reward function to verl's reward_score directory
cp SuperRL-Opensource/reward/superrl.py verl/utils/reward_score/superrl.py

Step 2: Modify verl's reward_score/__init__.py file

Add SuperRL reward function support to the verl/utils/reward_score/__init__.py file. Add the following to the default_compute_score function:

# Add to the default_compute_score function
elif data_source in ["openai/gsm8k", "lighteval/MATH", "DigitalLearningGmbH/MATH-lighteval", 
                     "HuggingFaceH4/MATH-500", "GAIR/LIMO", "meta-math/MetaMathQA", 
                     "open-r1/OpenR1-Math-220k", "horseee/MixChain-Z-PRM12K", "hitab"]:
    # Use SuperRL's universal reward function
    from . import superrl
    res = superrl.compute_score(solution_str, ground_truth)
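
Before training, it can help to check that every `data_source` value in your dataset is covered by the branch above. The sketch below copies the source list from that branch; the helper name `unrouted_sources` is hypothetical and not part of verl or this repository. A source it reports may still be handled by another branch of `default_compute_score`, so treat the result as a prompt to check, not as an error.

```python
from typing import Iterable, List

# Data sources routed to superrl.compute_score in the branch above.
SUPERRL_DATA_SOURCES = frozenset({
    "openai/gsm8k",
    "lighteval/MATH",
    "DigitalLearningGmbH/MATH-lighteval",
    "HuggingFaceH4/MATH-500",
    "GAIR/LIMO",
    "meta-math/MetaMathQA",
    "open-r1/OpenR1-Math-220k",
    "horseee/MixChain-Z-PRM12K",
    "hitab",
})


def unrouted_sources(data_sources: Iterable[str]) -> List[str]:
    """Return the sorted data sources not covered by the SuperRL branch."""
    return sorted(set(data_sources) - SUPERRL_DATA_SOURCES)


if __name__ == "__main__":
    # "my-org/custom-math" is a made-up source used only for illustration.
    print(unrouted_sources(["openai/gsm8k", "hitab", "my-org/custom-math"]))
```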

Step 3: Ensure correct reward function return format

SuperRL's reward function already returns a float, which is the type verl expects, so no conversion is needed.

2. Actor Implementation Integration

Step 1: Copy SuperRL actors to verl workers directory

# Copy all SuperRL actor implementations
cp SuperRL-Opensource/actor/* verl/workers/actor/

Step 2: Modify fsdp_workers.py to support SuperRL Actor

Modify the actor initialization part in the init_model method of verl/workers/fsdp_workers.py:

# Modify in the init_model method of ActorRolloutRefWorker class
if self._is_actor:
    actor_cfg = omega_conf_to_dataclass(self.config.actor)
    
    # Check if using SuperRL Actor
    actor_type = self.config.actor.get("actor_type", "default")
    
    if actor_type == "superrl":
        from verl.workers.actor.SuperRLActor import SuperRLActor
        self.actor = SuperRLActor(
            config=actor_cfg, 
            actor_module=self.actor_module_fsdp, 
            tokenizer=self.tokenizer,
            actor_optimizer=self.actor_optimizer
        )
    elif actor_type == "hybrid_adv_gated":
        from verl.workers.actor.HybridAdvGatedActor import HybridAdvGatedActor
        self.actor = HybridAdvGatedActor(
            config=actor_cfg, 
            actor_module=self.actor_module_fsdp, 
            tokenizer=self.tokenizer,
            actor_optimizer=self.actor_optimizer
        )
    elif actor_type == "hybrid_log_sigma":
        from verl.workers.actor.HybridLogSigmaActor import HybridLogSigmaActor
        self.actor = HybridLogSigmaActor(
            config=actor_cfg, 
            actor_module=self.actor_module_fsdp, 
            tokenizer=self.tokenizer,
            actor_optimizer=self.actor_optimizer
        )
    else:
        # Default to original DataParallelPPOActor
        from verl.workers.actor import DataParallelPPOActor
        self.actor = DataParallelPPOActor(
            config=actor_cfg, 
            actor_module=self.actor_module_fsdp, 
            actor_optimizer=self.actor_optimizer
        )
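
If the `if`/`elif` chain above becomes hard to maintain, the same dispatch can be written as a lookup table. The sketch below is an alternative under stated assumptions, not code from this repository: `ACTOR_REGISTRY`, `resolve_actor_path`, and `load_actor_class` are hypothetical names, while the module paths and class names are copied from the branch above. Unlike that branch, which silently falls back to the default actor, this version raises on an unknown `actor_type`, so a typo in the configuration fails early.

```python
import importlib
from typing import Tuple

# actor_type -> (module path, class name), copied from the branch above.
ACTOR_REGISTRY = {
    "superrl": ("verl.workers.actor.SuperRLActor", "SuperRLActor"),
    "hybrid_adv_gated": ("verl.workers.actor.HybridAdvGatedActor", "HybridAdvGatedActor"),
    "hybrid_log_sigma": ("verl.workers.actor.HybridLogSigmaActor", "HybridLogSigmaActor"),
    "default": ("verl.workers.actor", "DataParallelPPOActor"),
}


def resolve_actor_path(actor_type: str) -> Tuple[str, str]:
    """Map an actor_type string to its (module, class) names."""
    if actor_type not in ACTOR_REGISTRY:
        valid = ", ".join(sorted(ACTOR_REGISTRY))
        raise ValueError(f"Unknown actor_type {actor_type!r}; expected one of: {valid}")
    return ACTOR_REGISTRY[actor_type]


def load_actor_class(actor_type: str):
    """Import and return the actor class (requires verl to be installed)."""
    module_name, class_name = resolve_actor_path(actor_type)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
```

Note that in the branch above the three SuperRL actors take an extra `tokenizer` argument that `DataParallelPPOActor` does not, so the constructor call still needs to distinguish the two cases.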

Step 3: Configuration file support

Add the actor_type parameter and the SuperRL-specific options to the training configuration:

# Add to training configuration file
actor_rollout_ref:
  actor:
    actor_type: "default"  # Options: "superrl", "hybrid_adv_gated", "hybrid_log_sigma", "default"
    # SuperRL specific configurations
    sft_micro_batch_size: 2
    pg_signal_eps: 1e-8
    reward_eps: 1e-8
    sft_label_smoothing: 0.0

# Add HybridDataset configuration
data:
  use_hybrid_dataset: false  # Set to true to enable HybridDataset for SuperRL
  # Other data configurations...
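
A small validation step can catch misconfigured values before a long training run starts. The sketch below is illustrative only: the class `SuperRLActorOptions` is hypothetical, its field names and defaults are copied from the configuration above, and the accepted ranges (non-negative epsilons, label smoothing in [0, 1)) are assumptions, not documented constraints of the SuperRL actors.

```python
from dataclasses import dataclass

VALID_ACTOR_TYPES = ("default", "superrl", "hybrid_adv_gated", "hybrid_log_sigma")


@dataclass(frozen=True)
class SuperRLActorOptions:
    """Hypothetical container mirroring the actor keys shown above."""

    actor_type: str = "default"
    sft_micro_batch_size: int = 2
    pg_signal_eps: float = 1e-8
    reward_eps: float = 1e-8
    sft_label_smoothing: float = 0.0

    def __post_init__(self) -> None:
        if self.actor_type not in VALID_ACTOR_TYPES:
            raise ValueError(f"actor_type must be one of {VALID_ACTOR_TYPES}")
        if self.sft_micro_batch_size < 1:
            raise ValueError("sft_micro_batch_size must be at least 1")
        # Assumed range: thresholds are compared against magnitudes.
        if self.pg_signal_eps < 0 or self.reward_eps < 0:
            raise ValueError("pg_signal_eps and reward_eps must be non-negative")
        # Assumed range: the usual interval for label smoothing.
        if not 0.0 <= self.sft_label_smoothing < 1.0:
            raise ValueError("sft_label_smoothing must be in [0, 1)")


if __name__ == "__main__":
    print(SuperRLActorOptions(actor_type="superrl"))
```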

3. Dataset Integration

Step 1: Copy HybridDataset to verl's dataset module

# Copy dataset components to verl's dataset module (run from the verl project root)
cp SuperRL-Opensource/dataset/hybrid_dataset.py verl/utils/dataset/
cp SuperRL-Opensource/dataset/response_dataproto_compose.py verl/utils/dataset/

Step 2: Add imports to verl's dataset __init__.py

Add the following imports to verl/utils/dataset/__init__.py:

# Add these imports to verl/utils/dataset/__init__.py
from .hybrid_dataset import HybridDataset
from .response_dataproto_compose import *

Step 3: Modify main_ppo.py to support HybridDataset

Add support for HybridDataset in the create_rl_dataset function in verl/trainer/main_ppo.py:

from omegaconf import ListConfig  # needed for the isinstance check below

def create_rl_dataset(data_config, tokenizer, processor=None):
    """Create RL dataset with support for HybridDataset"""
    
    # Choose the dataset class from the configuration. HybridDataset is opt-in:
    # set data.use_hybrid_dataset=true when training with SuperRL.
    use_hybrid_dataset = data_config.get("use_hybrid_dataset", False)
    
    if use_hybrid_dataset:
        from verl.utils.dataset import HybridDataset
        dataset_cls = HybridDataset
    else:
        from verl.utils.dataset.rl_dataset import RLHFDataset  # Default verl dataset
        dataset_cls = RLHFDataset
    
    # Extract data file paths
    data_paths = []
    if hasattr(data_config, 'train_files') and data_config.train_files:
        if isinstance(data_config.train_files, (list, ListConfig)):
            data_paths.extend(data_config.train_files)
        else:
            data_paths.append(data_config.train_files)
    
    # Instantiate the dataset using the determined dataset class
    dataset = dataset_cls(
        data_files=data_paths,
        tokenizer=tokenizer,
        processor=processor,
        config=data_config,
    )
    
    return dataset

Step 4: Modify ray_trainer.py to handle additional batch keys

Update the batch processing in verl/trainer/ppo/ray_trainer.py to include the new keys:

# In the ray_trainer.py file, find the batch_keys_to_pop definition and update it:
batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
non_tensor_batch_keys_to_pop = ["raw_prompt_ids", "tagged_answer"]

# Make sure the batch processing logic can handle these additional keys
def process_batch(batch):
    # Extract tensor keys
    for key in batch_keys_to_pop:
        if key in batch:
            # Process tensor keys as before
            pass
    
    # Extract non-tensor keys 
    for key in non_tensor_batch_keys_to_pop:
        if key in batch:
            # Handle non-tensor keys (like tagged_answer)
            # These should be preserved in non_tensor_batch
            pass
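
The placeholder logic above can be made concrete with a small stand-in. The sketch below uses plain dictionaries in place of verl's batch object, and the helper name `split_batch` is hypothetical. It shows only the intended behavior: listed keys are removed from the batch and returned in two groups, and `tagged_answer` travels with the non-tensor group.

```python
from typing import Any, Dict, Iterable, Tuple

BATCH_KEYS_TO_POP = ["input_ids", "attention_mask", "position_ids"]
NON_TENSOR_BATCH_KEYS_TO_POP = ["raw_prompt_ids", "tagged_answer"]


def split_batch(
    batch: Dict[str, Any],
    tensor_keys: Iterable[str] = BATCH_KEYS_TO_POP,
    non_tensor_keys: Iterable[str] = NON_TENSOR_BATCH_KEYS_TO_POP,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Pop the listed keys out of `batch` and return them in two groups.

    Keys that are absent are skipped, so datasets without
    `tagged_answer` keep working. `batch` is modified in place.
    """
    tensors = {k: batch.pop(k) for k in tensor_keys if k in batch}
    non_tensors = {k: batch.pop(k) for k in non_tensor_keys if k in batch}
    return tensors, non_tensors


if __name__ == "__main__":
    batch = {
        "input_ids": [[1, 2, 3]],
        "attention_mask": [[1, 1, 1]],
        "tagged_answer": ["<answer>42</answer>"],  # made-up value
        "uid": ["a"],
    }
    tensors, non_tensors = split_batch(batch)
    print(sorted(tensors), sorted(non_tensors), sorted(batch))
```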

4. Data Preprocessing Setup

Step 1: Copy data preprocessing scripts

# Copy data preprocessing scripts to a suitable location
mkdir -p data_preprocessing
cp SuperRL-Opensource/data_preprocess/* data_preprocessing/

5. Training Script Modifications

Step 1: Modify demo script to use SuperRL

Update demo_script/superrl_demo.sh:

# SuperRL-specific overrides: data.use_hybrid_dataset and the actor_rollout_ref.actor.* lines.
# ... keep other configurations unchanged ...
# Note: comments and blank lines must not appear between backslash-continued lines,
# otherwise the shell ends the command early.
python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.use_hybrid_dataset=true \
    actor_rollout_ref.actor.actor_type=superrl \
    actor_rollout_ref.actor.sft_micro_batch_size=2 \
    actor_rollout_ref.actor.pg_signal_eps=1e-8 \
    actor_rollout_ref.actor.reward_eps=1e-8 \
    trainer.total_epochs=100 "$@"

Citation

If you find this repository useful, please consider giving it a ⭐ or citing:

@misc{liu2025superrlreinforcementlearningsupervision,
      title={SuperRL: Reinforcement Learning with Supervision to Boost Language Model Reasoning}, 
      author={Yihao Liu and Shuocheng Li and Lang Cao and Yuhang Xie and Mengyu Zhou and Haoyu Dong and Xiaojun Ma and Shi Han and Dongmei Zhang},
      year={2025},
      eprint={2506.01096},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2506.01096}, 
}

Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.

Trademarks

This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party's policies.

License

This project is released under the MIT License; see LICENSE.txt.
