Prefix-RL

Parameter-efficient RL using prefix optimization. A small adapter model (1B-1.5B) is trained via PPO to generate the first k tokens (a prefix) that a larger frozen target model (7B-72B) uses to solve math problems. The reward is binary: did the target model get the right answer?

Built on OpenRLHF v0.6.1 (stripped to PPO-only).

This repo supports two training modes:

  1. Direct Math RL - Train a model to solve math problems directly (for baseline comparison).
  2. Prefix-RL - Train a small adapter to generate short prefixes that help a large target model.

Note: This is a work-in-progress repository. We are actively updating it based on our original codebase. The goal is to showcase the core idea of Prefix-RL and how we configured it — expect rough edges.

Architecture

┌─────────────────────────────────────────────────┐
│                PPO Training Loop                │
│   (Ray-distributed: Actor, Critic, Reference)   │
│                                                 │
│  Adapter Model (e.g., Qwen2.5-1.5B)             │
│  → Generates prefix (k=64 tokens)               │
└──────────────────┬──────────────────────────────┘
                   │ prefix
                   ▼
┌─────────────────────────────────────────────────┐
│               Reward Model Server               │
│                                                 │
│  Direct Math: verify answer via math_verify     │
│  Prefix-RL:   Target model (7B+) + prefix       │
│               → solve → binary reward           │
└─────────────────────────────────────────────────┘
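The loop above can be condensed into a few lines of Python. This is a minimal sketch: `adapter_generate`, `target_generate`, and `verify_answer` are illustrative stand-ins, not the repo's actual API.

```python
def verify_answer(solution: str, reference: str) -> bool:
    # Stand-in for the repo's math_verify-based checker.
    return reference in solution

def prefix_rl_reward(problem, adapter_generate, target_generate, reference, k=64):
    # 1. The small trainable adapter produces the first k tokens.
    prefix = adapter_generate(problem, max_new_tokens=k)
    # 2. The large frozen target model continues from the prefix.
    solution = target_generate(problem + prefix)
    # 3. Binary reward: 1.0 iff the target reaches the right answer.
    return 1.0 if verify_answer(solution, reference) else 0.0
```

Only the adapter receives PPO gradients; the target model is queried purely for this scalar reward.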

Setup

Prerequisites

  • Python 3.10+
  • CUDA 12.4+
  • At least 1 GPU (Direct Math RL) or 4+ GPUs (Prefix-RL with large target model)

Installation

On a GPU node (needed for flash-attn compilation):

salloc -p kempner_h100 --gres=gpu:1 -n 1 --time=4:00:00 --mem=128G -c 8 --account=kempner_sham_lab

module load python/3.10.13-fasrc01
module load cuda/12.9.1-fasrc01
module load cudnn/9.10.2.21_cuda12-fasrc01

conda create -n prefix_rl python=3.10 -y
conda activate prefix_rl

# PyTorch (CUDA 12.4)
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124

# vLLM (install before other deps to avoid version conflicts)
pip install vllm==0.7.2

# flash-attn (compiles from source, takes a few minutes)
pip install flash-attn --no-build-isolation

# Remaining dependencies
pip install -r requirements.txt

Or use the setup script:

bash scripts/setup_env.sh

For W&B logging:

wandb login

Quick Start: Direct Math RL (GSM8K)

Trains a model to solve GSM8K math problems directly. Requires 1 GPU (for testing the pipeline).

Step 1: Start the reward model server

python -m openrlhf.cli.serve_rm_math --port 5000 --dataset gsm8k

Step 2: Start Ray and run PPO training

ray start --head --dashboard-port 8265
./scripts/train_ppo.sh configs/gsm8k_ppo_sweep.yaml

Prefix-RL (MATH)

Trains a small adapter (Qwen2.5-1.5B) to generate 64-token prefixes that help a large target model (Qwen2.5-7B+) solve MATH problems. Requires 4+ GPUs (for the target model + training).

Step 1: Start the target model reward server

On a node with GPUs for the target model:

python -m openrlhf.cli.serve_rm_hint_offline \
    --config configs/hint_rm_config.yaml \
    --port 5001

Edit configs/hint_rm_config.yaml to set the target model and tensor_parallel_size.
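Based on the two fields called out above, the config plausibly looks something like this (the field names here are assumptions, not a verified schema — check the file in the repo):

```yaml
# Hypothetical sketch of configs/hint_rm_config.yaml.
target_model: Qwen/Qwen2.5-7B-Instruct  # frozen target that consumes the prefix
tensor_parallel_size: 4                 # vLLM tensor parallelism across the node's GPUs
```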

Step 2: Run PPO training for the adapter

On the training node (update remote_rm_url if the reward server is on a different node):

ray start --head --dashboard-port 8265

SWEEP_CONFIG=configs/hint_ppo_sweep.yaml python scripts/ppo_sweep.py sweep_config=configs/hint_ppo_sweep.yaml

Or via SLURM:

sbatch scripts/run_ppo_sweep.sh

SLURM Usage

# Start reward server on one node
sbatch scripts/serve_rm_math.sh      # for direct math
sbatch scripts/serve_rm_hint.sh      # for Prefix-RL

# Submit training (update remote_rm_url in config to point to reward server node)
sbatch scripts/run_ppo_sweep.sh

The sweep script (scripts/ppo_sweep.py) reads the YAML config, computes a cartesian product of sweep parameters, and each SLURM array task runs one configuration.
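The mapping from sweep config to array task can be sketched as follows. This is a simplified illustration of the cartesian-product logic, not the actual contents of scripts/ppo_sweep.py, and the grid keys are made up for the example.

```python
import itertools

# Hypothetical sweep grid -- the real keys live in the YAML sweep config.
sweep = {
    "actor_learning_rate": [1e-6, 5e-7],
    "init_kl_coef": [1e-3, 1e-2],
}

# Cartesian product: every combination of sweep values is one configuration.
keys = sorted(sweep)
configs = [
    dict(zip(keys, values))
    for values in itertools.product(*(sweep[k] for k in keys))
]

def config_for_task(task_id: int) -> dict:
    # Each SLURM array task (SLURM_ARRAY_TASK_ID) picks exactly one configuration.
    return configs[task_id]
```

With the grid above this yields 4 configurations, so the SLURM array would be submitted with indices 0-3.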

Key Hyperparameters

| Parameter            | Direct Math  | Prefix-RL          |
|----------------------|--------------|--------------------|
| pretrain (adapter)   | Llama-3.2-1B | Qwen2.5-1.5B       |
| generate_max_len     | 1024         | 64                 |
| n_samples_per_prompt | 1            | 8                  |
| init_kl_coef         | 1e-3         | 1e-3               |
| train_batch_size     | 128          | 512                |
| prompt_data          | openai/gsm8k | SIEH/MATH-filtered |
| eval_data            | gsm8k        | math               |

Evaluation Datasets

Supported via --eval_data: gsm8k, math, aime_subset, aime2024, olympiadbench_physics, ocwcourses.

Changes from OpenRLHF v0.6.1

  • Stripped to PPO-only — removed DPO, KTO, KD, RM, SFT, and PRM trainers.
  • Evaluation during training — --eval_data/--eval_steps flags with greedy generation + math verification.
  • Custom reward servers — serve_rm_math.py (direct verification) and serve_rm_hint_offline.py (target-model evaluation via vLLM).
  • Math verification — math_verifier.py with \boxed{} parsing, <llm-code> execution, and subprocess sandboxing.
  • Temperature-scaled log probs — the Actor model supports a temperature parameter for PPO.
  • Log-scale checkpointing — --save_log_scale_count for logarithmically spaced saves.

Citation

If you find this work useful, please cite:

@inproceedings{rochafilho2026parameter,
  title={Parameter-Efficient Reinforcement Learning using Prefix Optimization},
  author={Rocha Filho, Itamar and Zhao, Rosie and Kakade, Sham M. and Malach, Eran and Jelassi, Samy},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2026}
}
