12 changes: 12 additions & 0 deletions .gitignore
@@ -208,3 +208,15 @@ server_gpu*.log
server.log
*.output
skills/multinode/

*.toml
*.json
*.png
*.md
*.sbatch
*.sh

# Allow tracked source configs and scripts
!configs/**/*.toml
!scripts/*.sbatch
!scripts/*.sh
102 changes: 50 additions & 52 deletions CLAUDE.md
@@ -12,24 +12,20 @@ the full technical walkthrough.
|------|---------|
| `src/prime_rl/inference/compaction/worker.py` | Generation + compaction logic (single & batch) |
| `src/prime_rl/inference/compaction/routes.py` | `/compact_generate` endpoint + auto-batching |
| `src/prime_rl/inference/compaction/algorithm.py` | Attention Matching + NNLS beta solver (suffix queries + forced indices by default) |
| `src/prime_rl/inference/compaction/algorithm.py` | Attention Matching + NNLS beta solver |
| `src/prime_rl/inference/compaction/beta_attention.py` | BetaState mirrors + SDPA decode with per-token bias |
| `src/prime_rl/trainer/rl/compaction.py` | Beta training hooks + deterministic compaction replay + query capture hooks + forced indices |
| `src/compaction_env/env.py` | CompactionEnv (verifiers SingleTurnEnv wrapper) |
| `scripts/eval_rg_mix.py` | rg-mix-env evaluation (compaction and baseline modes) |

### Configs

| Config | Purpose |
|--------|---------|
| `configs/compaction/qwen3_4b_fullft_suffix_queries.toml` | **Default** — 4+4 layout, suffix queries + forced indices, no beta |
| `configs/compaction/qwen3_4b_fullft_determ_nobeta.toml` | Random queries, deterministic compaction, no beta |
| `configs/compaction/qwen3_4b_fullft_determ_suffix.toml` | Deterministic random queries + prompt keys |
| `configs/compaction/qwen3_4b_fullft_fixed_1024q.toml` | 1024 random queries |
| `configs/compaction/qwen3_4b_fullft_nobeta.toml` | 4+4 layout, no beta (pre-deterministic, legacy) |
| `scripts/eval_rg_mix.py` | rg-mix-env evaluation (compaction, baseline, and RSA modes) |
| `scripts/eval_aime_rsa.py` | AIME benchmark for RSA vs baseline comparison |
| `scripts/eval_balrog_babyai.py` | BabyAI (MiniGrid) multi-turn eval (compaction, baseline, markovian) |
| `configs/compaction/qwen3_4b_balrog_babyai.toml` | BabyAI baseline training config |
| `scripts/start_4servers.sh` | Launch 4 TP=1 servers for DP=4 |
| `configs/compaction/qwen3_4b_fullft_train.toml` | **Default training config** — 2-node, mixed-mode |
| `configs/compaction/qwen3_4b_beta_test.toml` | Beta attention test config |
| `configs/compaction/qwen3_4b_fullft_baseline.toml` | Baseline (no compaction) |
| `configs/compaction/qwen3_4b_serve_tp1.toml` | TP=1 inference server |
| `configs/compaction/qwen3_4b_markovian_test.toml` | Markovian mode — Qwen3-4B, 50 steps |
| `configs/compaction/qwen3_06b_markovian_test.toml` | Markovian mode — Qwen3-0.6B, fast E2E test |
| `configs/compaction/qwen3_4b_serve_tp1.toml` | TP=1 server config (compaction) |
| `configs/compaction/qwen3_06b_serve_tp1.toml` | TP=1 server config (0.6B) |

### How it works

@@ -39,36 +35,16 @@ is compacted using Attention Matching: select top-k keys by attention importance
solve least-squares for replacement values (C2), optionally compute NNLS beta bias for
partition function correction, then inject `[prompt | C1/C2 | suffix]` back into paged blocks.
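The C1/C2 construction can be sketched with plain NumPy. This is a toy, dense approximation — the real `compact_kv()` runs per layer on paged KV and adds the NNLS beta step — and all names here are illustrative:

```python
import numpy as np

def compact_window(K_w, V_w, Q, k):
    """Toy Attention Matching sketch: pick top-k keys (C1), solve
    least-squares replacement values (C2). Not the repo's API."""
    # Score each window key by total attention mass received from probe queries.
    logits = Q @ K_w.T / np.sqrt(K_w.shape[1])           # (m, n)
    attn = np.exp(logits - logits.max(axis=1, keepdims=True))
    attn /= attn.sum(axis=1, keepdims=True)
    importance = attn.sum(axis=0)                        # (n,)

    # C1: keep the top-k keys by importance.
    idx = np.sort(np.argsort(importance)[-k:])
    C1 = K_w[idx]

    # C2: values such that attention over the kept keys reproduces
    # the full-window attention output as closely as possible.
    target = attn @ V_w                                  # (m, d) full-window outputs
    A_sub = attn[:, idx]
    A_sub = A_sub / A_sub.sum(axis=1, keepdims=True)     # renormalise over kept keys
    C2, *_ = np.linalg.lstsq(A_sub, target, rcond=None)  # (k, d)
    return idx, C1, C2
```

With suffix queries, `Q` would be real post-RoPE query vectors instead of random probes; the selection and solve steps are unchanged.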

**Suffix queries + forced indices (default)**: Inference uses real suffix token attention
queries to score key importance (`use_suffix_queries=true`). Since vLLM and HuggingFace
produce numerically different query vectors, inference returns per-event top-k indices in
`diagnostics.compaction_indices`. The trainer passes these as `forced_indices` to
`compact_kv()`, skipping importance scoring and guaranteeing identical key selection. C2
values are recomputed from the trainer's KV cache using its own suffix queries for correct
gradients. This combination is the default because suffix queries provide +3% accuracy at
extreme compression (1024→32) while forced indices eliminate key selection mismatch.

**Full-context scoring**: `compact_kv()` scores key importance using the full KV cache
(prompt + assistant keys in the softmax denominator). Window keys redundant with prompt
content score lower.

**Deterministic compaction** (alternative): `compact_kv()` can use seeded random queries
(`seed + layer_idx`) instead of suffix queries, ensuring identical compaction between
inference and training replay without needing forced indices.

**Beta attention**: When `compute_beta=true`, the NNLS solver finds per-key additive biases
that correct the partition function mismatch between full and compacted attention.
`BetaState` maintains contiguous KV mirrors alongside paged cache. Decode switches from
FlashAttention to SDPA+beta via monkey-patched attention layers. Two separate CUDA graph
captures handle pre-compaction (FlashAttention) and post-compaction (SDPA+beta) phases.
Beta training hooks (`compaction.py`) inject matching bias into attention_mask during FSDP2
training for consistency.
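The decode-time bias can be illustrated with a dense SDPA sketch (toy NumPy version; the real path monkey-patches vLLM's attention layers and `beta` comes from the NNLS solver):

```python
import numpy as np

def sdpa_with_beta(Q, K, V, beta):
    """Scaled dot-product attention with a per-key additive logit bias.
    beta (shape (n,)) broadcasts across queries; beta = 0 recovers plain SDPA."""
    logits = Q @ K.T / np.sqrt(K.shape[1]) + beta
    w = np.exp(logits - logits.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    return w @ V
```

Because beta enters the logits before the softmax, it rescales each key's share of the partition function — which is exactly the mismatch the NNLS solve corrects.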

**Suffix queries**: When `use_suffix_queries=true`, the compaction algorithm uses real query
vectors from suffix tokens instead of random Gaussian probes. A prefill pass re-runs the
suffix through the model with hooks on vLLM's inner `Attention` class to capture post-RoPE
queries at every layer. Model-agnostic (works for any architecture using vLLM's Attention).
+3% accuracy over random probes at ~20% slower wall time (extra prefill per compaction event).
**Markovian mode**: When `compaction_mode="markovian"`, the window is hard-deleted instead
of compressed. The cache becomes `[prompt | suffix]` with no C1/C2. Supported in both
inference (`worker.py`) and trainer (`segmented_forward`). Config field `compaction_mode`
is auto-synced from env args to trainer config.
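A minimal sketch of the markovian hard-delete, assuming a flat token/KV list (the real code operates on paged KV blocks):

```python
def markovian_compact(kv, prompt_len, window):
    """Toy sketch of compaction_mode="markovian" (illustrative, not worker.py's API).

    kv is [prompt | window | suffix]; the window is hard-deleted, so no
    C1/C2 replacement keys/values are computed at all."""
    return kv[:prompt_len] + kv[prompt_len + window:]
```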

**Auto-batching**: Individual `/compact_generate` requests are transparently batched into
`compact_generate_batch` calls (B=8) by `_RequestBatcher` in routes.py.
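The batching idea can be sketched with asyncio futures. This is illustrative only — `_RequestBatcher` in routes.py may differ in details such as timeout handling:

```python
import asyncio

class RequestBatcher:
    """Toy auto-batcher: collect requests until max_batch or max_wait,
    then run one batched call and resolve each caller's future."""

    def __init__(self, batch_fn, max_batch=8, max_wait=0.01):
        self.batch_fn = batch_fn
        self.max_batch = max_batch
        self.max_wait = max_wait
        self._pending = []      # (request, future) pairs
        self._flusher = None    # pending timer handle

    async def submit(self, request):
        fut = asyncio.get_running_loop().create_future()
        self._pending.append((request, fut))
        if len(self._pending) >= self.max_batch:
            self._flush()       # batch is full: flush immediately
        elif self._flusher is None:
            self._flusher = asyncio.get_running_loop().call_later(
                self.max_wait, self._flush)
        return await fut

    def _flush(self):
        if self._flusher is not None:
            self._flusher.cancel()
            self._flusher = None
        batch, self._pending = self._pending, []
        if not batch:
            return
        results = self.batch_fn([req for req, _ in batch])
        for (_, fut), res in zip(batch, results):
            fut.set_result(res)
```

Callers just `await batcher.submit(req)` and never see the batching; that is the property the `/compact_generate` endpoint relies on.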
@@ -84,22 +60,29 @@ queries at every layer. Model-agnostic (works for any architecture using vLLM's
- CUDA graphs: only with TP=1; batch mode uses two-phase capture (pre/post compaction)
- `empty_cache()` between compaction segments in trainer prevents CUDA fragmentation OOM
- AC disable in segmented_forward is LoRA-only (Full FT keeps AC enabled)
- **Suffix queries + forced indices is default**: Inference uses suffix queries for
importance scoring and passes top-k indices to trainer. Do NOT regress to random-only.
- **Full-context softmax**: `compact_kv()` scores importance over all keys
(prompt + assistant). Do NOT regress to assistant-only scoring.
- **NNLS target must use K_all_h (all keys)**, not Kw_h (window-only). Using window-only
produces beta that's too small, causing 2.4x lower reward.

### Training (default: 2-node, 4+4 layout)
### Training

Node 1: 4 inference servers (TP=1, ports 8000-8003)
Node 2: 4 trainer GPUs (FSDP2) + orchestrator (CPU)
Default: 2-node mixed mode (5 inference + 3 trainer GPUs). See the `multinode` skill for
launch instructions.

```bash
sbatch -A m5017 -C "gpu&hbm80g" --qos=premium --time 48:00:00 --gpus-per-node 4 --nodes=2 ~/compaction_suffix_queries.sh
sbatch -A m5017 -C "gpu&hbm80g" --qos=premium --time 24:00:00 --gpus-per-node 4 --nodes=2 ~/compaction_multinode.sh
```

### RSA (Recursive Self-Aggregation)

`rsa_generate` on `CompactionWorker` implements RSA V2 with persistent compacted memory.
Prefills the question, forks the KV into K candidates, and generates an initial population;
each round then selects peers, builds an aggregation prompt, append-prefills it onto the base
KV, generates probe tokens to capture attention patterns, compacts the aggregation region,
and generates new candidates.

Key helpers: `_fork_kv_blocks` (block-level KV copy), `_prefill_append` (chunked prefill
onto existing KV), `_batch_generate` (K candidates in parallel), `_inject_compacted_range`
(range-based KV injection), `compact_kv_range` in algorithm.py (range-based compaction).

Endpoint: `/rsa_generate` in routes.py. No auto-batching (RSA uses full GPU internally).
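The control flow can be sketched as follows, with toy `generate`/`aggregate` callables standing in for the KV-level helpers above (illustrative, not `rsa_generate`'s real signature):

```python
import random

def rsa_loop(generate, aggregate, K=4, T=2, k_peers=2, seed=0):
    """Toy skeleton of the RSA V2 loop. In the real worker, generate()
    operates on forked KV blocks and aggregate() builds, prefills, and
    compacts the aggregation prompt against the persistent base KV."""
    rng = random.Random(seed)
    population = [generate(None) for _ in range(K)]   # initial candidates
    for _ in range(T):
        new_population = []
        for _ in range(K):
            peers = rng.sample(population, k_peers)   # select peers
            context = aggregate(peers)                # build + compact aggregation
            new_population.append(generate(context))  # refined candidate
        population = new_population
    return population
```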

### Running evals

```bash
@@ -110,9 +93,24 @@ python scripts/eval_rg_mix.py --mode compaction --n 100 \

python scripts/eval_rg_mix.py --mode baseline --n 100

# With suffix queries:
# RSA mode
python scripts/eval_rg_mix.py --mode rsa --n 100 \
--rsa-K 4 --rsa-T 2 --rsa-k-peers 2 --rsa-probe-tokens 512

# Markovian mode
python scripts/eval_rg_mix.py --mode compaction --n 100 \
--max-kv-len 2048 --max-total-tokens 8192 \
--n-compacts 99 --compact-ratio 0.25 --compact-window 1024 \
--use-suffix-queries
--n-compacts 99 --compact-window 1024 \
--compaction-mode markovian

# AIME benchmark
python scripts/eval_aime_rsa.py --mode rsa --n 30 --rsa-K 4 --rsa-T 2
python scripts/eval_aime_rsa.py --mode baseline --n 30

# BabyAI multi-turn grid-world (via minigrid)
python scripts/eval_balrog_babyai.py --mode baseline --n 10
python scripts/eval_balrog_babyai.py --mode compaction --n 10 \
--max-kv-len 2048 --compact-ratio 0.25
python scripts/eval_balrog_babyai.py --mode markovian --n 10 \
--max-kv-len 2048
```
62 changes: 62 additions & 0 deletions configs/compaction/qwen3_06b_markovian_test.toml
@@ -0,0 +1,62 @@
# Compaction Markovian E2E Test — Qwen3-0.6B (single-node, 4 GPU)
# 2 inference (TP=1, DP=2) + 2 trainer (FSDP2)
# Markovian mode: hard-deletes compaction window, no attention matching.

max_steps = 20
seq_len = 4096
output_dir = "../scratch/ckpts/compaction-markovian-test"

[model]
name = "Qwen/Qwen3-0.6B"

[wandb]
project = "compaction-rl"
name = "compaction-markovian-test-06b"

[ckpt]
interval = 10
resume_step = -1

[deployment]
num_train_gpus = 2
num_infer_gpus = 0

[trainer]
dist_timeout_seconds = 3600
compact_target_ratio = 0.25
compact_window = 1024
compaction_mode = "markovian"

[trainer.model]
impl = "auto"
optim_cpu_offload = true

[trainer.model.ac]
freq = 1

[trainer.optim]
lr = 1e-6
weight_decay = 0.01
betas1 = 0.9
betas2 = 0.9

[trainer.loss]
type = "default"

[orchestrator]
batch_size = 16
rollouts_per_example = 2

[orchestrator.client]
base_url = [
"http://localhost:8000/v1",
"http://localhost:8001/v1",
]

[orchestrator.sampling]
max_tokens = 3072

[[orchestrator.env]]
id = "compaction_env"
name = "compaction-math"
args = { gym = "math_env", max_seq_len = 4096, max_kv_len = 2048, max_total_tokens = 3072, compact_target_ratio = 0.25, compact_window = 1024, n_compacts = 99, use_suffix_queries = false, compaction_mode = "markovian" }
14 changes: 14 additions & 0 deletions configs/compaction/qwen3_06b_serve_tp1.toml
@@ -0,0 +1,14 @@
gpu_memory_utilization = 0.95

[server]
host = "0.0.0.0"

[model]
name = "Qwen/Qwen3-0.6B"
enforce_eager = true

[parallel]
tp = 1

[vllm_extra]
enable_compaction = true
70 changes: 70 additions & 0 deletions configs/compaction/qwen3_4b_balrog_babyai.toml
@@ -0,0 +1,70 @@
# BabyAI Training — Qwen3-4B (baseline, multi-turn via balrog-bench)
# Grid-world RL: model navigates, picks up objects, unlocks doors.
# 2-node: Node 1 = 4 inference (TP=1), Node 2 = 4 trainer (FSDP2)
#
# Requires: `uv run prime env install prime-community/balrog-bench`
# plus the full BALROG framework (cmake + NLE) for BalrogEnv.
#
# For compaction evaluation (no training), use:
# python scripts/eval_balrog_babyai.py --mode compaction --n 10
# which uses minigrid directly (lighter, no NLE needed).
#
# __INFERENCE_NODE__ is replaced by the launch script.

max_steps = 200
seq_len = 4096
output_dir = "outputs/balrog-babyai"

[model]
name = "Qwen/Qwen3-4B"

[wandb]
project = "balrog-rl"
name = "balrog-babyai-baseline"

[ckpt]
interval = 10
resume_step = -1

[deployment]
num_train_gpus = 4
num_infer_gpus = 0

[trainer]
dist_timeout_seconds = 3600

[trainer.model]
impl = "auto"
optim_cpu_offload = true

[trainer.model.ac]
freq = 1

[trainer.optim]
lr = 1e-6
weight_decay = 0.01
betas1 = 0.9
betas2 = 0.9

[trainer.loss]
type = "default"

[orchestrator]
batch_size = 64
rollouts_per_example = 4

[orchestrator.client]
base_url = [
"http://__INFERENCE_NODE__:8000/v1",
"http://__INFERENCE_NODE__:8001/v1",
"http://__INFERENCE_NODE__:8002/v1",
"http://__INFERENCE_NODE__:8003/v1",
]

[orchestrator.sampling]
max_tokens = 4096

[[orchestrator.env]]
id = "balrog-bench"
name = "balrog-babyai"
args = { environments = ["babyai"], max_text_history = 16 }
5 changes: 5 additions & 0 deletions configs/compaction/qwen3_4b_baseline_tp1.toml
@@ -0,0 +1,5 @@
[model]
name = "Qwen/Qwen3-4B"

[parallel]
tp = 1
78 changes: 78 additions & 0 deletions configs/compaction/qwen3_4b_markovian_mila.toml
@@ -0,0 +1,78 @@
# Markovian Compaction Training — Qwen3-4B (Mila, 4xA100)
# 2 inference (TP=1, DP=2) + 2 trainer (FSDP2)
# Markovian mode: hard-deletes compaction window, no attention matching.
# rg_mix_env with reasoning_gym scoring.

max_steps = 600
seq_len = 9216
output_dir = "../scratch/outputs/compaction_markovian_mila"

[model]
name = "Qwen/Qwen3-4B"

[wandb]
project = "compaction-rl"
name = "compaction-markovian-4b-mila"

[trainer.wandb]
id = "zn4ybskm"

[orchestrator.wandb]
id = "texk5786"

[ckpt]
interval = 15
resume_step = -1
keep_last = 1

[deployment]
num_train_gpus = 2
num_infer_gpus = 2

[trainer]
dist_timeout_seconds = 3600
compact_target_ratio = 0.25
compact_window = 1024
compaction_mode = "markovian"

[trainer.model]
impl = "auto"
optim_cpu_offload = true

[trainer.model.ac]
freq = 1

[trainer.optim]
lr = 1e-6
weight_decay = 0.01
betas1 = 0.9
betas2 = 0.9

[trainer.loss]
type = "default"

[orchestrator]
batch_size = 256
rollouts_per_example = 8
use_token_client = false

[orchestrator.sampling]
temperature = 1.0
max_tokens = 8192

[[orchestrator.env]]
id = "compaction_env"
name = "compaction-rg-mix"
args = { gym = "rg_mix_env", max_seq_len = 9216, max_kv_len = 2048, max_total_tokens = 8192, compact_target_ratio = 0.25, compact_window = 1024, n_compacts = 99, use_suffix_queries = false, compaction_mode = "markovian", dataset_path = "../scratch/datasets/rg_mix_7500", num_train_examples = 7500, num_eval_examples = 100 }

[orchestrator.verification]
enabled = false

[inference]
gpu_memory_utilization = 0.95

[inference.model]
enforce_eager = true

[inference.vllm_extra]
enable_compaction = true