This script runs OmegaPRM on multiple GPUs, with each GPU handling a different part of the dataset for parallel processing.
- **Split the Input Data**

  Use `process_json.py` to split your input JSON file into one part per GPU:

  ```bash
  python process_json.py --input_file questions.json --output_dir output_directory --num_splits 8
  ```
- **Run the Script**

  Use `run_omegaprm_multi_gpu.sh` to start OmegaPRM processing on each GPU:

  ```bash
  bash run_omegaprm_multi_gpu.sh
  ```
Results are saved in `output_results`.

**Note:** Before running, make sure to set the correct parameter values in the script. Important parameters include:
- `MODEL_NAME`: Path to the model (e.g., a Hugging Face or vLLM model).
- `MODEL_TYPE`: Set to `"hf"` for Hugging Face or `"vllm"` for vLLM support.
- Other parameters such as `MAX_NEW_TOKENS`, `TEMPERATURE`, and `TOP_K`, along with the remaining hyperparameters, according to your needs.
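The split step can be sketched in a few lines of Python. This is only a hedged approximation of what `process_json.py` does; the output file naming `questions_part_{i}.json` is an assumption, not the script's documented behavior:

```python
import json
import math
import os

def split_questions(input_file, output_dir, num_splits):
    """Split a JSON list of questions into num_splits roughly equal chunks,
    one file per GPU. Approximates process_json.py; exact naming may differ."""
    with open(input_file) as f:
        questions = json.load(f)
    os.makedirs(output_dir, exist_ok=True)
    # Ceiling division so every question lands in exactly one chunk.
    chunk_size = math.ceil(len(questions) / num_splits)
    for i in range(num_splits):
        chunk = questions[i * chunk_size:(i + 1) * chunk_size]
        # Hypothetical per-GPU file name.
        with open(os.path.join(output_dir, f"questions_part_{i}.json"), "w") as f:
            json.dump(chunk, f, indent=2)
```

Each GPU then consumes exactly one of the resulting part files.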
To run OmegaPRM on a single GPU:

```bash
CUDA_VISIBLE_DEVICES=0 python run_omegaprm.py \
  --question_file ../extracted_problems_and_answers.json \
  --output_dir output_results \
  --model_name "Your model name or path" \
  --model_type $MODEL_TYPE \
  --device "cuda" \
  --max_new_tokens 2048 \
  --temperature 0.7 \
  --top_k 30 \
  --top_p 0.9 \
  --c_puct 0.125 \
  --alpha 0.5 \
  --beta 0.9 \
  --length_scale 500 \
  --num_rollouts 16 \
  --max_search_count 20 \
  --rollout_budget 200 \
  --save_data_tree True \
  --log_file_prefix "log/omega_prm_single_gpu"
```
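Under the hood, a multi-GPU launch is just one such process per device, each pinned via `CUDA_VISIBLE_DEVICES` and pointed at its own data split. A hedged Python sketch of that pattern follows; the split file names and per-GPU output directories are hypothetical, and `run_omegaprm_multi_gpu.sh` may organize things differently:

```python
import os

def build_gpu_commands(num_gpus, split_dir="output_directory"):
    """Build one run_omegaprm.py invocation per GPU, each pinned to its own
    device and pointed at its own data split (file names are hypothetical)."""
    jobs = []
    for gpu in range(num_gpus):
        # Pin the process to a single GPU.
        env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu))
        cmd = [
            "python", "run_omegaprm.py",
            "--question_file", f"{split_dir}/questions_part_{gpu}.json",
            "--output_dir", f"output_results/part_{gpu}",
        ]
        jobs.append((env, cmd))
    return jobs

# To actually launch, start subprocess.Popen(cmd, env=env) for each job,
# then wait on all of them.
```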
For data generated by OmegaPRM_v2, two formats are available:
- **Flat Format** (`save_data_tree=False`): Each entry is structured as:

  ```
  { "solution_prefix": [Q, x_1:x_i], "mc_value": 0.5 }
  ```

  where `i` is the number of reasoning steps. This format provides a linear view of the reasoning process without hierarchical structure.
- **Tree Format** (`save_data_tree=True`): Data is organized as a tree structure, aligned with the figure presented in the paper. Each reasoning step (node) includes:
  - `text`: The cumulative reasoning from the root node up to this specific step.
  - `mc_value`: The Monte Carlo score computed for the reasoning progression up to this step.
  - `children`: A list of child nodes branching from the current node.
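Because each node stores the cumulative reasoning in `text`, a tree-format file can be flattened into flat-format-style `(solution_prefix, mc_value)` entries with a simple recursive walk. A minimal sketch, where the example tree and its values are illustrative only:

```python
def flatten_tree(node, out=None):
    """Collect (solution_prefix, mc_value) pairs from a tree-format node,
    walking depth-first through its children."""
    if out is None:
        out = []
    # `text` already holds the cumulative reasoning, so no prefix joining needed.
    out.append({"solution_prefix": node["text"], "mc_value": node["mc_value"]})
    for child in node.get("children", []):
        flatten_tree(child, out)
    return out

# Illustrative tree: one first step with two alternative continuations.
tree = {
    "text": "Q ... Step 1 ...",
    "mc_value": 0.8,
    "children": [
        {"text": "Q ... Step 1 ... Step 2a ...", "mc_value": 0.5, "children": []},
        {"text": "Q ... Step 1 ... Step 2b ...", "mc_value": 0.9, "children": []},
    ],
}
examples = flatten_tree(tree)
```

This yields one training example per node, which is convenient when a downstream trainer expects the flat layout.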