|
1 | | -<p align="center"> |
2 | | -</p> |
| 1 | +# KV Cache Compaction RL |
3 | 2 |
|
4 | | -<p align="center"> |
5 | | - <img src="https://github.com/user-attachments/assets/40c36e38-c5bd-4c5a-9cb3-f7b902cd155d#gh-light-mode-only" alt="Prime Intellect" width="312"> |
6 | | - <img src="https://github.com/user-attachments/assets/6414bc9b-126b-41ca-9307-9e982430cde8#gh-dark-mode-only" alt="Prime Intellect" width="312"> |
7 | | -</p> |
| 3 | +RL training with mid-generation KV cache compaction using [Attention Matching](https://arxiv.org/abs/2602.16284). Compressing the cache between generation segments lets the model learn over long effective contexts (8K+ tokens) within a fixed KV memory budget (2K tokens). |
8 | 4 |
|
9 | | ---- |
| 5 | +Built on [PRIME-RL](https://github.com/PrimeIntellect-ai/prime-rl) for async RL and [vLLM](https://github.com/vllm-project/vllm) for inference. |
10 | 6 |
|
11 | | -<h3 align="center"> |
12 | | -PRIME-RL: Async RL Training at Scale |
13 | | -</h3> |
| 7 | +## How It Works |
14 | 8 |
|
15 | | ---- |
| 9 | +The `CompactionWorker` drives model forward passes inside vLLM's `collective_rpc`, bypassing the scheduler for full control over the KV cache. Generation proceeds until the KV cache fills a budget (`max_kv_len`), then: |
16 | 10 |
|
17 | | -</br> |
18 | | -<p align="center"> |
19 | | - <a href="https://github.com/PrimeIntellect-ai/prime-rl/actions/workflows/style.yaml"> |
20 | | - <img src="https://github.com/PrimeIntellect-ai/prime-rl/actions/workflows/style.yaml/badge.svg" alt="Style" /> |
21 | | - </a> |
22 | | - <a href="https://github.com/PrimeIntellect-ai/prime-rl/actions/workflows/cpu_tests.yaml"> |
23 | | - <img src="https://github.com/PrimeIntellect-ai/prime-rl/actions/workflows/cpu_tests.yaml/badge.svg" alt="Test" /> |
24 | | - </a> |
25 | | - <a href="https://github.com/PrimeIntellect-ai/prime-rl/actions/workflows/gpu_tests.yaml"> |
26 | | - <img src="https://github.com/PrimeIntellect-ai/prime-rl/actions/workflows/gpu_tests.yaml/badge.svg" alt="Test" /> |
27 | | - </a> |
28 | | -</p> |
| 11 | +1. **Select** top-k keys by attention importance (C1) |
| 12 | +2. **Solve** least-squares for replacement values (C2) |
| 13 | +3. **Optionally compute** a per-key beta bias via non-negative least squares (NNLS) to correct the partition function |
| 14 | +4. **Inject** `[prompt | C1/C2 | suffix]` back into paged blocks |
| 15 | +5. **Continue** generating until `max_total_tokens` or EOS |
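
Steps 1 and 2 above can be illustrated with a minimal NumPy sketch for a single attention head. It is not the code in `algorithm.py`: the function name `compact_kv`, the use of summed attention mass as the importance score, and the use of a set of reference queries are all assumptions made for illustration.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def compact_kv(queries, keys, values, k):
    """Compress one head's (keys, values) to k entries (illustrative only)."""
    scale = 1.0 / np.sqrt(keys.shape[-1])
    attn = softmax(queries @ keys.T * scale)   # (n_q, n) full attention
    target = attn @ values                     # (n_q, d) outputs to reproduce

    # Step 1 (C1): keep the k keys that receive the most attention mass.
    # Summed mass is an assumed importance score, not necessarily the repo's.
    importance = attn.sum(axis=0)
    keep = np.sort(np.argsort(importance)[-k:])
    c1 = keys[keep]

    # Step 2 (C2): least-squares values so that attention over the kept
    # keys reproduces the original attention outputs as closely as possible.
    attn_c = softmax(queries @ c1.T * scale)   # (n_q, k)
    c2, *_ = np.linalg.lstsq(attn_c, target, rcond=None)
    return keep, c1, c2, attn_c, target


rng = np.random.default_rng(0)
q = rng.normal(size=(64, 16))
keys = rng.normal(size=(128, 16))
values = rng.normal(size=(128, 16))

keep, c1, c2, attn_c, target = compact_kv(q, keys, values, k=32)
err_fit = np.linalg.norm(attn_c @ c2 - target)            # solved values
err_copy = np.linalg.norm(attn_c @ values[keep] - target)  # original values
print(f"fit error {err_fit:.4f} vs copy error {err_copy:.4f}")
```

Because the original values of the kept keys are one feasible choice for C2, the least-squares solution can never reconstruct the outputs worse than simply copying them.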
29 | 16 |
|
30 | | -## Overview |
31 | | - |
32 | | -PRIME-RL is a framework for large-scale asynchronous reinforcement learning. It is designed to be easy-to-use and hackable, yet capable of scaling to 1000+ GPUs. Beyond that, here is why we think you might like it: |
33 | | - |
34 | | -1. Integrates natively with [`verifiers`](https://github.com/PrimeIntellect-ai/verifiers) environments via the [Environments Hub](https://app.primeintellect.ai/dashboard/environments?ex_sort=most_stars) |
35 | | -2. Supports end-to-end post-training, including SFT and RL training and evals |
36 | | -3. Multi-node deployment with [FSDP2](https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html) training and [vLLM](https://github.com/vllm-project/vllm) inference backend |
37 | | -4. Designed for asynchronous agentic RL training at scale |
38 | | -5. Hackable, modular and extensible by nature |
| 17 | +With `compute_beta=true`, an additive per-key bias corrects the mismatch in the softmax partition function after compaction. This path keeps contiguous KV mirrors (`BetaState`) alongside vLLM's paged cache and decodes with SDPA instead of FlashAttention on the bias-corrected path. |
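
To see why an additive bias can repair the partition function, consider a toy case in plain Python where three keys give the query identical logits and compaction keeps only one of them. The helper `partition` and the numbers are hypothetical; real keys are not exact duplicates, which is presumably why the bias is fitted with NNLS rather than computed in closed form.

```python
import math


def partition(logits, biases=None):
    """Softmax normalizer Z = sum_j exp(logit_j + bias_j)."""
    if biases is None:
        biases = [0.0] * len(logits)
    return sum(math.exp(l + b) for l, b in zip(logits, biases))


# Full cache: four keys, the last three score identically for this query.
full_logits = [1.2, 0.3, 0.3, 0.3]
z_full = partition(full_logits)

# Compacted cache: one representative survives from the three duplicates,
# so the normalizer is missing the mass of the two dropped keys.
compact_logits = [1.2, 0.3]
z_naive = partition(compact_logits)

# A bias of log(3) on the survivor makes it count as three keys.
beta = [0.0, math.log(3.0)]
z_corrected = partition(compact_logits, beta)

print(z_full, z_naive, z_corrected)
```

Adding `beta` to a logit multiplies that key's unnormalized weight by `exp(beta)`, so a non-negative weight per kept key maps directly to a bias in log space.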
39 | 18 |
|
40 | 19 | ## Setup |
41 | 20 |
|
42 | | -> *We develop and test on NVIDIA RTX 3090/4090/5090, A100, H100, H200, and B200. If your setup fails, please create an [issue](https://github.com/PrimeIntellect-ai/prime-rl/issues).* |
43 | | -
|
44 | | -### Prerequisites |
45 | | - |
46 | | -Currently, you **need at least one NVIDIA GPU to use PRIME-RL**. If you don't already have access to one, we recommend our [compute platform](https://app.primeintellect.ai) for everything from renting on-demand single GPUs for developing, debugging and small ablations, to [reserving 1000+ GPU clusters](https://app.primeintellect.ai/dashboard/quotes) for production-scale training. |
47 | | - |
48 | | -### Quick Setup |
49 | | - |
50 | | -Set up PRIME-RL in a single command. |
51 | | - |
52 | | -```bash |
53 | | -curl -sSL https://raw.githubusercontent.com/PrimeIntellect-ai/prime-rl/main/scripts/install.sh | bash |
54 | | -``` |
55 | | - |
56 | | -<details> |
57 | | -<summary> |
58 | | -Manual Setup |
59 | | -</summary> |
60 | | -<br> |
61 | | - |
62 | | -1. Clone the repository |
63 | | - |
64 | | -```bash |
65 | | -git clone https://github.com/PrimeIntellect-ai/prime-rl.git |
66 | | -cd prime-rl |
67 | | -``` |
68 | | - |
69 | | -2. Install [uv](https://docs.astral.sh/uv/) |
70 | | - |
71 | | -```bash |
72 | | -curl -LsSf https://astral.sh/uv/install.sh | sh |
73 | | -source $HOME/.local/bin/env |
74 | | -``` |
75 | | - |
76 | | -3. Install dependencies from the lock file |
77 | | - |
78 | 21 | ```bash |
| 22 | +git clone https://github.com/HyperPotatoNeo/attention-matching-rl.git |
| 23 | +cd attention-matching-rl |
79 | 24 | uv sync --all-extras |
80 | 25 | ``` |
81 | 26 |
|
82 | | -3.1. Optional: Install Flash Attention 3 (on Hopper GPUs only, for flash_attention_3 attention backend) |
| 27 | +## Inference |
83 | 28 |
|
84 | | -> *NOTE*: This step will take a while, as it builds the Flash Attention 3 extension from source, as it has no wheels prebuilt. |
85 | | -> *NOTE*: After this step, you can't run `uv sync --all-extras` or `uv run` as it will uninstall the package, you can avoid it by running `uv sync --inexact` or `uv run --no-sync` |
| 29 | +### Single-GPU server |
86 | 30 |
|
87 | 31 | ```bash |
88 | | -uv pip install "flash-attn-3 @ git+https://github.com/Dao-AILab/flash-attention.git@main#subdirectory=hopper" --no-build-isolation |
| 32 | +uv run inference @ configs/compaction/qwen3_4b_serve_tp1.toml --server.port 8000 |
89 | 33 | ``` |
90 | 34 |
|
91 | | -</details> |
92 | | - |
93 | | -<details> |
94 | | -<summary> |
95 | | -Validate your environment setup |
96 | | -</summary> |
97 | | -<br> |
| 35 | +### Test compaction generation |
98 | 36 |
|
99 | | -1. Check that the environment uses Python 3.12 |
| 37 | +```python |
| 38 | +import requests |
100 | 39 |
|
101 | | -```bash |
102 | | -uv run python -V |
| 40 | +resp = requests.post("http://localhost:8000/compact_generate", json={ |
| 41 | + "prompt_ids": tokenizer.encode("Solve this problem..."),  # tokenizer: any tokenizer matching the served model, loaded beforehand |
| 42 | + "max_kv_len": 2048, |
| 43 | + "max_total_tokens": 8192, |
| 44 | + "compact_target_ratio": 0.25, |
| 45 | + "compact_window": 1024, |
| 46 | + "n_compacts": 99, |
| 47 | + "compute_beta": True, |
| 48 | + "temperature": 0.6, |
| 49 | +}) |
| 50 | +result = resp.json() |
| 51 | +print(result["final_text"]) |
| 52 | +print(f"Tokens: {result['diagnostics']['total_tokens']}, " |
| 53 | + f"Compactions: {len(result['diagnostics']['compaction_events'])}") |
103 | 54 | ``` |
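
The request body can also be assembled by a small helper that catches obviously inconsistent settings before they reach the server. This is a sketch: `build_compact_request` is a hypothetical name, the field names are taken from the request above, and the validation rules are assumptions about sensible values rather than checks the endpoint is known to perform.

```python
def build_compact_request(
    prompt_ids,
    max_kv_len=2048,
    max_total_tokens=8192,
    compact_target_ratio=0.25,
    compact_window=1024,
    n_compacts=99,
    compute_beta=True,
    temperature=0.6,
):
    """Build a JSON-serializable body for POST /compact_generate."""
    prompt_ids = list(prompt_ids)
    if not prompt_ids:
        raise ValueError("prompt_ids must not be empty")
    # Assumed constraints, not confirmed against the server implementation.
    if len(prompt_ids) >= max_kv_len:
        raise ValueError("prompt must fit inside the KV budget")
    if max_total_tokens < max_kv_len:
        raise ValueError("max_total_tokens should be at least max_kv_len")
    if not 0.0 < compact_target_ratio < 1.0:
        raise ValueError("compact_target_ratio must be in (0, 1)")
    if compact_window > max_kv_len:
        raise ValueError("compact_window cannot exceed max_kv_len")
    return {
        "prompt_ids": prompt_ids,
        "max_kv_len": max_kv_len,
        "max_total_tokens": max_total_tokens,
        "compact_target_ratio": compact_target_ratio,
        "compact_window": compact_window,
        "n_compacts": n_compacts,
        "compute_beta": compute_beta,
        "temperature": temperature,
    }


payload = build_compact_request([101, 102, 103])
print(sorted(payload))
```

The returned dictionary can be passed as `json=payload` to `requests.post`.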
104 | 55 |
|
105 | | -2. Check that `flash-attn` is installed |
| 56 | +### Evaluate on rg-mix-env |
106 | 57 |
|
107 | 58 | ```bash |
108 | | -uv run python -c "import flash_attn" |
109 | | -``` |
| 59 | +# Start 4 TP=1 servers |
| 60 | +bash scripts/start_4servers.sh |
110 | 61 |
|
111 | | -3. Check that you can run SFT trainer (*this requires 1 GPU*) |
| 62 | +# Run compaction eval |
| 63 | +python scripts/eval_rg_mix.py --mode compaction --n 100 \ |
| 64 | + --max-kv-len 2048 --max-total-tokens 8192 \ |
| 65 | + --n-compacts 99 --compact-ratio 0.25 --compact-window 1024 |
112 | 66 |
|
113 | | -```bash |
114 | | -uv run sft @ configs/debug/sft/train.toml |
| 67 | +# Run baseline eval (no compaction) |
| 68 | +python scripts/eval_rg_mix.py --mode baseline --n 100 |
115 | 69 | ``` |
116 | 70 |
|
117 | | -4. Check that you can run the RL trainer (*this requires 1 GPU*) |
| 71 | +## Training |
118 | 72 |
|
119 | | -```bash |
120 | | -uv run trainer @ configs/debug/rl/train.toml |
121 | | -``` |
| 73 | +Training uses 2 nodes in mixed mode: 5 inference GPUs + 3 trainer GPUs. |
122 | 74 |
|
123 | | -5. Check that you can run the inference server (*this requires 1 GPU*) |
| 75 | +**Architecture:** |
| 76 | +- **Node 1** (4 GPUs): 4 independent TP=1 compaction servers (ports 8000-8003) |
| 77 | +- **Node 2** (4 GPUs): 1 inference server on GPU 0 (port 8004) + FSDP2 trainer on GPUs 1-3 |
124 | 78 |
|
125 | | -```bash |
126 | | -uv run inference @ configs/debug/infer.toml |
127 | | -``` |
128 | | - |
129 | | -*Keep the inference server running in the background for the next steps.* |
| 79 | +### Config |
130 | 80 |
|
131 | | -5.1. Check that you can run the orchestrator against the inference server |
| 81 | +The training config at `configs/compaction/qwen3_4b_beta_test.toml` uses: |
| 82 | +- `compute_beta = true` for partition function correction |
| 83 | +- `max_kv_len = 2048`, `compact_window = 1024`, `ratio = 0.25` |
| 84 | +- `max_total_tokens = 8192` effective context per rollout |
| 85 | +- `batch_size = 256`, `rollouts_per_example = 8` |
| 86 | +- Full fine-tune, lr=1e-6, AdamW with CPU offload |
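
A back-of-the-envelope simulation shows how these settings interact. It assumes that each compaction replaces `compact_window` cached tokens with `ratio * compact_window` entries, that compaction fires exactly when the cache reaches `max_kv_len`, and that `max_total_tokens` counts the prompt. None of these semantics are confirmed by the source, and `simulate_rollout` is a hypothetical helper.

```python
def simulate_rollout(
    prompt_len,
    max_kv_len=2048,
    max_total_tokens=8192,
    compact_window=1024,
    ratio=0.25,
    n_compacts=99,
):
    """Count compactions and peak cache size for one rollout (assumed semantics)."""
    kv_len = prompt_len   # entries currently held in the KV cache
    total = prompt_len    # tokens seen so far, prompt included
    peak = kv_len
    compactions = 0
    while total < max_total_tokens:
        if kv_len >= max_kv_len:
            if compactions >= n_compacts:
                break  # out of compactions, generation stops here
            kept = int(compact_window * ratio)
            kv_len -= compact_window - kept  # window shrinks to `kept` entries
            compactions += 1
        kv_len += 1  # decode one token
        total += 1
        peak = max(peak, kv_len)
    return total, compactions, peak


total, compactions, peak = simulate_rollout(prompt_len=256)
print(f"total={total} compactions={compactions} peak_kv={peak}")
```

Under these assumptions every compaction frees 768 cache slots, so a rollout with a 256-token prompt needs eight compactions to reach 8192 tokens while the cache never grows past 2048 entries.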
132 | 87 |
|
133 | | -```bash |
134 | | -uv run orchestrator @ configs/debug/orch.toml |
135 | | -``` |
| 88 | +### Launch |
136 | 89 |
|
137 | | -5.2. Check that you can run evals against the inference server |
| 90 | +1. **Create resolved config** — replace `__INFERENCE_NODE__` and `__TRAINER_NODE__` placeholders with actual hostnames |
| 91 | +2. **Node 1**: Run `scripts/start_4servers.sh` (or `multinode/compaction/node1_inference.sh`) |
| 92 | +3. **Node 2**: Run `multinode/compaction/node2_mixed.sh <resolved_config.toml>` |
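
Step 1 can be scripted. The sketch below substitutes the two placeholders and refuses to return a config that still contains any `__UPPER_CASE__` token. The function name `resolve_config`, the TOML key `base_url`, and the hostnames are hypothetical; only the placeholder names and ports come from this README.

```python
import re

# Matches leftover tokens such as __INFERENCE_NODE__.
PLACEHOLDER = re.compile(r"__[A-Z]+(?:_[A-Z]+)*__")


def resolve_config(template_text, inference_node, trainer_node):
    """Return the config text with hostname placeholders filled in."""
    resolved = (
        template_text
        .replace("__INFERENCE_NODE__", inference_node)
        .replace("__TRAINER_NODE__", trainer_node)
    )
    leftover = PLACEHOLDER.findall(resolved)
    if leftover:
        raise ValueError(f"unresolved placeholders: {sorted(set(leftover))}")
    return resolved


# Hypothetical template fragment; the real config keys may differ.
template = (
    'base_url = ["http://__INFERENCE_NODE__:8000/v1", '
    '"http://__TRAINER_NODE__:8004/v1"]\n'
)
resolved = resolve_config(template, "node-a", "node-b")
print(resolved)
```

Writing the result to a new file, rather than editing the template in place, keeps the template reusable across allocations.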
138 | 93 |
|
139 | | -```bash |
140 | | -uv run eval @ configs/debug/eval.toml |
141 | | -``` |
| 94 | +The trainer on node 2 waits for the inference server on GPU 0 to become ready, then starts the RL loop. Weight updates are broadcast via the filesystem. |
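
The wait-for-ready behavior can be approximated with a small polling loop. This is a generic sketch, not the logic in `node2_mixed.sh`: `wait_until_ready` is a hypothetical helper, and the health URL mentioned in the comment is an assumption about what the server exposes.

```python
import time


def wait_until_ready(probe, timeout_s=600.0, interval_s=5.0,
                     sleep=time.sleep, clock=time.monotonic):
    """Call probe() until it returns True; return the number of attempts."""
    deadline = clock() + timeout_s
    attempts = 0
    while True:
        attempts += 1
        try:
            if probe():
                return attempts
        except OSError:
            pass  # e.g. connection refused while the server is still loading
        if clock() >= deadline:
            raise TimeoutError(f"server not ready after {attempts} attempts")
        sleep(interval_s)


# A real probe might wrap urllib.request.urlopen("http://localhost:8004/health")
# (assumed endpoint). Here a fake probe succeeds on its third call.
responses = iter([False, False, True])
attempts = wait_until_ready(lambda: next(responses), sleep=lambda s: None)
print(f"ready after {attempts} attempts")
```

Injecting `sleep` and `clock` keeps the loop testable without real delays.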
142 | 95 |
|
143 | | -</details> |
| 96 | +## Key Files |
144 | 97 |
|
145 | | -### Additional Setup |
| 98 | +| File | Purpose | |
| 99 | +|------|---------| |
| 100 | +| `src/prime_rl/inference/compaction/worker.py` | Generation + compaction (single & batch) | |
| 101 | +| `src/prime_rl/inference/compaction/algorithm.py` | Attention Matching + NNLS beta solver | |
| 102 | +| `src/prime_rl/inference/compaction/beta_attention.py` | BetaState mirrors + SDPA decode with bias | |
| 103 | +| `src/prime_rl/inference/compaction/routes.py` | `/compact_generate` endpoint + auto-batching | |
| 104 | +| `src/compaction_env/env.py` | CompactionEnv (verifiers wrapper) | |
| 105 | +| `scripts/eval_rg_mix.py` | Evaluation script | |
146 | 106 |
|
147 | | -1. If you want to log your runs to [W&B](https://wandb.ai), log in |
| 107 | +## Configs |
148 | 108 |
|
149 | | -```bash |
150 | | -uv run wandb login |
151 | | -# Or set `export WANDB_API_KEY=...` |
152 | | -``` |
153 | | - |
154 | | -2. If you require gated/ private models or datasets from [HuggingFace](https://huggingface.co), log in |
155 | | - |
156 | | -```bash |
157 | | -uv run hf auth login |
158 | | -# Or set `export HF_TOKEN=...` |
159 | | -``` |
160 | | - |
161 | | -## Training Examples |
162 | | -We provide end-to-end training examples in the [`examples`](examples) directory to highlight features of the framework and guide you through the process of training your own models. |
163 | | -1. [**Reverse Text**](examples/reverse_text/README.md): Train `Qwen3-0.6B` to reverse a small chunk of text. Demonstrates tiny-scale single-turn SFT and RL training. Can be trained on a single consumer GPU in a few minutes, and is ideal for getting started. |
164 | | -2. [**Wordle**](examples/wordle/README.md): Train `Qwen3-1.7B` to play Wordle. A fun example of multi-turn SFT and RL training. Can be trained on a 2-4 H100 GPUs in a few hours. Ideal for exploring the multi-turn training capabilities of the framework. |
165 | | -3. [**Alphabet Sort**](examples/alphabet_sort/README.md): Train `Qwen3-4B-Instruct-2507` to sort names alphabetically. Demonstrates multi-turn RL training via LoRA without SFT warmup. Can be trained on a single H100 GPU in just over an hour. Ideal for exploring LoRA-based training. |
166 | | -4. [**Wiki Search**](examples/wiki_search/README.md): Train `Qwen3-4B-Instruct-2507` to answer trivia questions by searching through a Wikipedia. Demonstrates multi-turn with web search tool use. |
167 | | -5. [**Hendrycks Sanity**](examples/hendrycks_sanity/README.md): Run a sanity check experiment on `DeepSeek-R1-Distill-Qwen-1.5B` using a filtered subset of MATH where the model already partially solves 20-80% of problems. Useful for algorithm ablations. |
168 | | - |
169 | | -*More to come...* |
| 109 | +| Config | Purpose | |
| 110 | +|--------|---------| |
| 111 | +| `qwen3_4b_beta_test.toml` | Beta attention training (5 steps, for testing) | |
| 112 | +| `qwen3_4b_fullft_train.toml` | Full fine-tune training (production) | |
| 113 | +| `qwen3_4b_serve_tp1.toml` | TP=1 compaction server | |
| 114 | +| `qwen3_4b_baseline.toml` | TP=4 baseline (no compaction) | |
170 | 115 |
|
171 | 116 | ## Docs |
172 | 117 |
|
173 | | -Check out the [docs](docs) directory for in-depth guides on how to use PRIME-RL. |
174 | | - |
175 | | -- [**Entrypoints**](docs/entrypoints.md) - Overview of the main components (orchestrator, trainer, inference) and how to run SFT, RL, and evals |
176 | | -- [**Configs**](docs/configs.md) - Configuration system using TOML files, CLI arguments, and environment variables |
177 | | -- [**Environments**](docs/environments.md) - Installing and using verifiers environments from the Environments Hub |
178 | | -- [**Async Training**](docs/async.md) - Understanding asynchronous off-policy training and step semantics |
179 | | -- [**Logging**](docs/logging.md) - Logging with loguru, torchrun, and Weights & Biases |
180 | | -- [**Checkpointing**](docs/checkpointing.md) - Saving and resuming training from checkpoints |
181 | | -- [**Benchmarking**](docs/benchmarking.md) - Performance benchmarking and throughput measurement |
182 | | -- [**Deployment**](docs/deployment.md) - Training deployment on single-GPU, multi-GPU, and multi-node clusters |
183 | | -- [**Troubleshooting**](docs/troubleshooting.md) - Common issues and their solutions |
184 | | - |
185 | | -## Contributing |
186 | | - |
187 | | -We warmly welcome community contributions! We use [issues](https://github.com/PrimeIntellect-ai/prime-rl/issues) to track bugs, feature requests, and share our internal roadmap. If you encounter bugs, have pain points during development, or have ideas for new features, please open an issue. |
188 | | - |
189 | | -Contributions are welcome via PR. Please follow these guidelines: |
190 | | -1. Install the [pre-commit hooks](#pre-commit-hooks) to ensure your code is formatted correctly. |
191 | | -2. Please keep your PR in "Draft" until it is ready for review. |
192 | | -3. If your PR resolves an issue, please link the issue in the PR description |
193 | | -4. If you can, try running the [test suite](#tests) locally to ensure your changes are working as expected. |
194 | | - |
195 | | -### Pre-Commit Hooks |
196 | | - |
197 | | -Please install the [pre-commit](https://pre-commit.com) hooks to ensure your code is formatted correctly. |
198 | | - |
199 | | -```bash |
200 | | -uv run pre-commit install |
201 | | -``` |
202 | | - |
203 | | -### Tests |
204 | | - |
205 | | -Run the full test suite |
206 | | - |
207 | | -```bash |
208 | | -uv run pytest -v |
209 | | -``` |
210 | | - |
211 | | -To run unit tests, run |
212 | | - |
213 | | -```bash |
214 | | -uv run pytest tests/unit -v |
215 | | -``` |
216 | | - |
217 | | -To run integration tests, run |
218 | | - |
219 | | -```bash |
220 | | -uv run pytest tests/integration -v |
221 | | -``` |
222 | | - |
223 | | -To run CPU-only tests, use the inverse of the `gpu` marker: |
224 | | - |
225 | | -```bash |
226 | | -uv run pytest -v -m "not gpu" |
227 | | -``` |
228 | | - |
229 | | -## License |
230 | | - |
231 | | -This project is licensed under the Apache 2.0 license, as found in the [License](LICENSE) file. |
| 118 | +- [Implementation details](docs/compaction/IMPLEMENTATION.md) — algorithm, beta correction, CUDA graphs |
| 119 | +- [Speed optimizations](docs/compaction/SPEED_OPTIMIZATION.md) — batching, graph capture, profiling |
232 | 120 |
|
233 | 121 | ## Citation |
234 | 122 |
|
235 | | -If you find our work useful, feel free to cite it using |
| 123 | +Based on [Attention Matching](https://arxiv.org/abs/2602.16284): |
236 | 124 |
|
237 | | -```tex |
238 | | -@misc{primeintellect2025prime-rl, |
239 | | - author = {Prime Intellect}, |
240 | | - title = {PRIME-RL}, |
241 | | - url = {https://github.com/PrimeIntellect-ai/prime-rl}, |
242 | | - year = {2025} |
| 125 | +```bibtex |
| 126 | +@article{zweiger2025attention, |
| 127 | + title={Attention Matching: an Attention Decomposition Framework for Efficient KV Cache Compression}, |
| 128 | + author={Zweiger, Adam}, |
| 129 | + journal={arXiv preprint arXiv:2602.16284}, |
| 130 | + year={2026} |
243 | 131 | } |
244 | 132 | ``` |