Parameter-efficient RL using prefix optimization. A small adapter model (1B-1.5B) is trained via PPO to generate the first k tokens (a prefix) that a larger frozen target model (7B-72B) uses to solve math problems. The reward is binary: did the target model get the right answer?
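In pseudocode terms, the reward boils down to a single check. A minimal sketch (the helper names `target_generate` and `extract_answer` are hypothetical stand-ins for illustration, not the repo's API):

```python
def prefix_reward(problem, prefix, target_generate, extract_answer, gold):
    """Binary Prefix-RL reward: 1.0 iff the frozen target model,
    conditioned on the adapter's prefix, produces the correct answer."""
    # The adapter's prefix is prepended to the target model's solution attempt.
    completion = target_generate(problem + "\n" + prefix)  # frozen 7B-72B target
    return 1.0 if extract_answer(completion) == gold else 0.0
```

Only the small adapter receives gradient updates; the target model is queried as a black box.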
Built on OpenRLHF v0.6.1 (stripped to PPO-only).
Note: This is a work-in-progress repository. We are actively updating it based on our original codebase. The goal is to showcase the core idea of Prefix-RL and how we configured it — expect rough edges.

This repo supports two training modes:

- Direct Math RL - Train a model to solve math problems directly (for baseline comparison).
- Prefix-RL - Train a small adapter to generate short prefixes that help a large target model.
```
┌─────────────────────────────────────────────────┐
│ PPO Training Loop                               │
│ (Ray-distributed: Actor, Critic, Reference)     │
│                                                 │
│ Adapter Model (e.g., Qwen2.5-1.5B)              │
│ → Generates prefix (k=64 tokens)                │
└──────────────────┬──────────────────────────────┘
                   │ prefix
                   ▼
┌─────────────────────────────────────────────────┐
│ Reward Model Server                             │
│                                                 │
│ Direct Math: verify answer via math_verify      │
│ Prefix-RL: Target model (7B+) + prefix          │
│            → solve → binary reward              │
└─────────────────────────────────────────────────┘
```
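The reward server is a plain HTTP service. A minimal stand-in sketch, assuming a JSON protocol of `{"query": [...]}` in and `{"rewards": [...]}` out (the exact field names are an assumption; `score` is a placeholder for math_verify or the target-model rollout):

```python
import json
from http.server import BaseHTTPRequestHandler, HTTPServer

def score(query: str) -> float:
    """Placeholder scorer: the real servers run math verification
    or a full target-model solve; here we just string-match."""
    return 1.0 if "\\boxed{4}" in query else 0.0

class RewardHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        body = json.loads(self.rfile.read(int(self.headers["Content-Length"])))
        rewards = [score(q) for q in body["query"]]
        payload = json.dumps({"rewards": rewards}).encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        self.wfile.write(payload)

# To serve (blocks forever):
# HTTPServer(("0.0.0.0", 5000), RewardHandler).serve_forever()
```

The trainer only needs the server's URL (`remote_rm_url`), so the reward model can live on a different node than the PPO loop.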
- Python 3.10+
- CUDA 12.4+
- At least 1 GPU (Direct Math RL) or 4+ GPUs (Prefix-RL with large target model)
On a GPU node (needed for flash-attn compilation):
```bash
salloc -p kempner_h100 --gres=gpu:1 -n 1 --time=4:00:00 --mem=128G -c 8 --account=kempner_sham_lab
module load python/3.10.13-fasrc01
module load cuda/12.9.1-fasrc01
module load cudnn/9.10.2.21_cuda12-fasrc01

conda create -n prefix_rl python=3.10 -y
conda activate prefix_rl

# PyTorch (CUDA 12.4)
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124

# vLLM (install before other deps to avoid version conflicts)
pip install vllm==0.7.2

# flash-attn (compiles from source, takes a few minutes)
pip install flash-attn --no-build-isolation

# Remaining dependencies
pip install -r requirements.txt
```

Or use the setup script:

```bash
bash scripts/setup_env.sh
```

For W&B logging:

```bash
wandb login
```

Direct Math RL trains a model to solve GSM8K math problems directly. Requires 1 GPU (for testing the pipeline).
```bash
# Start the reward server
python -m openrlhf.cli.serve_rm_math --port 5000 --dataset gsm8k

# Start Ray and launch training
ray start --head --dashboard-port 8265
./scripts/train_ppo.sh configs/gsm8k_ppo_sweep.yaml
```

Prefix-RL trains a small adapter (Qwen2.5-1.5B) to generate 64-token prefixes that help a large target model (Qwen2.5-7B+) solve MATH problems. Requires 4+ GPUs (for the target model + training).
On a node with GPUs for the target model:
```bash
python -m openrlhf.cli.serve_rm_hint_offline \
    --config configs/hint_rm_config.yaml \
    --port 5001
```

Edit configs/hint_rm_config.yaml to set the target model and tensor_parallel_size.
On the training node (update remote_rm_url if the reward server is on a different node):
```bash
ray start --head --dashboard-port 8265
SWEEP_CONFIG=configs/hint_ppo_sweep.yaml python scripts/ppo_sweep.py sweep_config=configs/hint_ppo_sweep.yaml
```

Or via SLURM:

```bash
# Start reward server on one node
sbatch scripts/serve_rm_math.sh   # for direct math
sbatch scripts/serve_rm_hint.sh   # for Prefix-RL

# Submit training (update remote_rm_url in config to point to reward server node)
sbatch scripts/run_ppo_sweep.sh
```

The sweep script (scripts/ppo_sweep.py) reads the YAML config, computes a cartesian product of sweep parameters, and each SLURM array task runs one configuration.
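The sweep expansion described above can be sketched as follows (the sweep keys and values here are hypothetical examples, not the repo's defaults):

```python
import itertools
import os

def expand_sweep(sweep: dict) -> list[dict]:
    """Cartesian product of sweep parameters -> one flat config per combination."""
    keys = list(sweep)
    return [dict(zip(keys, vals)) for vals in itertools.product(*sweep.values())]

# Hypothetical sweep values for illustration.
sweep = {"actor_learning_rate": [1e-6, 5e-7], "init_kl_coef": [1e-3, 1e-2]}
configs = expand_sweep(sweep)  # 4 configurations

# Each SLURM array task picks exactly one configuration by its index.
task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
run_config = configs[task_id]
```

Size the SLURM array to the product of the sweep-value counts so every combination runs exactly once.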
| Parameter | Direct Math | Prefix-RL |
|---|---|---|
| `pretrain` (adapter) | Llama-3.2-1B | Qwen2.5-1.5B |
| `generate_max_len` | 1024 | 64 |
| `n_samples_per_prompt` | 1 | 8 |
| `init_kl_coef` | 1e-3 | 1e-3 |
| `train_batch_size` | 128 | 512 |
| `prompt_data` | openai/gsm8k | SIEH/MATH-filtered |
| `eval_data` | gsm8k | math |
Supported via `--eval_data`: `gsm8k`, `math`, `aime_subset`, `aime2024`, `olympiadbench_physics`, `ocwcourses`.
- Stripped to PPO-only — removed DPO, KTO, KD, RM, SFT, PRM trainers.
- Evaluation during training — `--eval_data`/`--eval_steps` flags with greedy generation + math verification.
- Custom reward servers — `serve_rm_math.py` (direct verification) and `serve_rm_hint_offline.py` (target model evaluation via vLLM).
- Math verification — `math_verifier.py` with `\boxed{}` parsing, `<llm-code>` execution, subprocess sandboxing.
- Temperature-scaled log probs — Actor model supports a temperature parameter for PPO.
- Log-scale checkpointing — `--save_log_scale_count` for logarithmically spaced saves.
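One way to compute logarithmically spaced save steps, as a sketch; the exact spacing used by `--save_log_scale_count` may differ:

```python
def log_scale_steps(total_steps: int, count: int) -> list[int]:
    """Pick ~count save steps spaced evenly on a log scale from 1 to
    total_steps (deduplicated and sorted). Assumes count >= 2."""
    if count < 2:
        return [total_steps]
    steps = {round(total_steps ** (i / (count - 1))) for i in range(count)}
    return sorted(max(1, s) for s in steps)
```

Early training changes fast and late training changes slowly, so log-spaced checkpoints capture the interesting dynamics with far fewer saves than a fixed interval.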
If you find this work useful, please cite:
```bibtex
@inproceedings{rochafilho2026parameter,
  title={Parameter-Efficient Reinforcement Learning using Prefix Optimization},
  author={Rocha Filho, Itamar and Zhao, Rosie and Kakade, Sham M. and Malach, Eran and Jelassi, Samy},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2026}
}
```