
Commit a4a7f76

exp_074_prequant_ttt: Pre-quant AdamW TTT (ready, untested)
Runs AdamW TTT on the full-precision EMA model BEFORE GPTQ quantization.
Based on PR openai#1364, which reports -0.027 BPB from this technique alone.

Flow: Train -> EMA -> AdamW TTT (3 epochs, freeze 2 blocks) -> GPTQ -> eval

Key fix: destroy_process_group + reinit pattern to avoid NCCL watchdog
timeout during the ~13-min single-rank TTT phase. Standard dist.barrier()
is insufficient because NCCL's heartbeat thread times out independently.

Env: PREQUANT_TTT_ENABLED=1 PREQUANT_TTT_EPOCHS=3 PREQUANT_TTT_LR=3e-4
PREQUANT_TTT_FREEZE_BLOCKS=2
1 parent a8412e3 commit a4a7f76

3 files changed

Lines changed: 2555 additions & 0 deletions


Lines changed: 54 additions & 0 deletions
# exp_074_prequant_ttt — Pre-quant AdamW TTT (READY, untested)

**Hypothesis**: Running AdamW TTT on the **full-precision EMA model before GPTQ** should give a much larger BPB improvement than post-quant SGD TTT.

**Source**: [PR #1364](https://github.com/openai/parameter-golf/pull/1364) reports −0.027 BPB from this technique alone (1.1025 BPB 3-seed mean).

## Why this works

Post-quant SGD TTT on int6 weights is unstable: we observed a +0.030 BPB
penalty with naive SGD TTT on a GPTQ-quantized model (see PR #756's
25 failed attempts). Running TTT **before** quantization:

1. Avoids optimizer instability on the quantized weight manifold
2. Lets GPTQ see the TTT-adapted Hessians during calibration
3. Uses AdamW (not SGD) for better adaptation dynamics
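
To make the pre-quant phase concrete, here is a minimal sketch of an AdamW TTT loop with the earliest blocks frozen. Names are assumptions for illustration (`ema_model` with a `.blocks` list, `ttt_batches` yielding (input, target) token batches, a forward pass that returns the loss); this is not the train_gpt.py implementation.

```python
import torch

def prequant_ttt(ema_model, ttt_batches, epochs=3, lr=3e-4, freeze_blocks=2):
    # Freeze the first `freeze_blocks` transformer blocks; adapt the rest.
    for i, block in enumerate(ema_model.blocks):        # assumed attribute
        if i < freeze_blocks:
            for p in block.parameters():
                p.requires_grad_(False)
    params = [p for p in ema_model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(params, lr=lr)
    ema_model.train()
    for _ in range(epochs):
        for x, y in ttt_batches:
            loss = ema_model(x, targets=y)              # assumed: forward returns CE loss
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
    return ema_model  # hand the adapted, still full-precision model to GPTQ
```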

## Flow

```
Train 600s → EMA model (bf16)
→ AdamW TTT on full-precision model (3 epochs)
→ GPTQ quantize the adapted model
→ Sliding window eval (no further TTT)
```
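
The payoff of quantizing after TTT (point 2 in the list above) is that GPTQ's calibration runs on the adapted model: GPTQ accumulates a per-layer Hessian proxy from the activations each linear layer sees, so those statistics now match the weights being quantized. A simplified sketch of the statistic, following the standard GPTQ formulation rather than this repo's code:

```python
import torch

def accumulate_hessian(H, layer_input):
    # GPTQ's per-layer Hessian proxy is H = sum of 2 * x x^T over calibration
    # tokens, where x is the input activation of the linear layer. With
    # pre-quant TTT, these activations come from the TTT-adapted model.
    X = layer_input.float()         # (n_tokens, d_in)
    return H + 2.0 * (X.t() @ X)    # (d_in, d_in) running sum
```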

## NCCL Timeout Fix

Pre-quant TTT runs for ~13 minutes on rank 0 only, exceeding NCCL's default
watchdog timeout (600s). A plain `dist.barrier()` is insufficient because
NCCL's heartbeat thread times out independently, so we tear down the process
group before the TTT phase and re-create it afterwards:

```python
import torch.distributed as dist

if distributed:
    dist.barrier()                  # sync all ranks before teardown
    dist.destroy_process_group()    # stops NCCL's watchdog/heartbeat threads

# ... rank 0 runs TTT (~13 min, no collectives) ...

if distributed:
    dist.init_process_group(backend="nccl", device_id=device)
    for p in base_model.parameters():
        dist.broadcast(p.data, src=0)   # push TTT-adapted weights to all ranks
```
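
One detail the snippet elides is what ranks 1-7 do while rank 0 runs TTT: with the process group destroyed they cannot wait on a collective. A filesystem sentinel is one way to bridge that gap; this is our assumption for illustration, not something the commit shows.

```python
import os
import time

DONE_FLAG = "/tmp/exp074_ttt_done"          # hypothetical sentinel path

def wait_out_ttt_phase(rank, run_ttt):
    """run_ttt: callable performing the single-rank TTT phase (rank 0 only)."""
    if rank == 0:
        run_ttt()
        open(DONE_FLAG, "w").close()        # signal completion to other ranks
    else:
        while not os.path.exists(DONE_FLAG):
            time.sleep(10)                  # plain polling: no NCCL, no watchdog
```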

## Running

```bash
PREQUANT_TTT_ENABLED=1 PREQUANT_TTT_EPOCHS=3 PREQUANT_TTT_LR=3e-4 \
PREQUANT_TTT_FREEZE_BLOCKS=2 GPTQ_ENABLED=1 GPTQ_N_BATCHES=64 \
TTT_ENABLED=0 EVAL_STRIDE=64 SEED=1337 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```
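
The flags map onto the flow above: the `PREQUANT_TTT_*` variables configure the new pre-quant phase (epochs, learning rate, frozen blocks), `GPTQ_ENABLED=1` and `GPTQ_N_BATCHES=64` control quantization and (presumably) its calibration batch count, and `TTT_ENABLED=0` disables the old post-quant TTT so only the pre-quant variant runs.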

## Expected Result

Targeting ~1.115-1.13 BPB, i.e. a −0.01 to −0.027 BPB gain from the post-EMA
1.142 (1.142 − 0.027 = 1.115). Full PR #1364 reports 1.1025 BPB at 6 epochs;
we use 3 epochs to halve the TTT time.
Lines changed: 42 additions & 0 deletions
#!/bin/bash
# Parameter Golf - exp_074: Pre-quant AdamW TTT
# Requirements: 8xH100 SXM, PyTorch 2.x, CUDA 12.x
# Expected time: ~35 min total (12 min FA3 build + 10 min train + 13 min TTT + 5 min GPTQ/eval)
set -e

echo "=== Step 1: Install dependencies ==="
pip install tiktoken blobfile tqdm lm_eval sentencepiece 2>/dev/null

echo "=== Step 2: Build Flash Attention 3 (Hopper kernels) ==="
echo "This takes ~12 minutes. DO NOT SKIP - FA3 gives ~86ms/step vs ~100ms with FA2."
pip install flash-attn --no-build-isolation 2>&1 | tail -5
# Verify FA3
python -c "from flash_attn_interface import flash_attn_func; print('FA3 OK')" 2>/dev/null \
    || python -c "from flash_attn import flash_attn_func; print('FA2 fallback (slower)')"

echo "=== Step 3: Download training data ==="
# The script auto-downloads data, but we can pre-fetch for speed
python -c "
import subprocess, os
os.makedirs('data', exist_ok=True)
if not os.path.exists('data/cached_challenge_fineweb.py'):
    subprocess.run(['wget', '-q', '-O', 'data/cached_challenge_fineweb.py',
                    'https://raw.githubusercontent.com/openai/parameter-golf/main/data/cached_challenge_fineweb.py'])
" 2>/dev/null
python data/cached_challenge_fineweb.py --variant sp1024 --train-shards 80 2>&1 | tail -3

echo "=== Step 4: Run experiment ==="
echo "Training 600s → EMA → Pre-quant AdamW TTT (3 epochs) → GPTQ → Sliding eval"
PREQUANT_TTT_ENABLED=1 \
PREQUANT_TTT_EPOCHS=3 \
PREQUANT_TTT_LR=3e-4 \
PREQUANT_TTT_FREEZE_BLOCKS=2 \
GPTQ_ENABLED=1 \
GPTQ_N_BATCHES=64 \
TTT_ENABLED=0 \
EVAL_STRIDE=64 \
SEED=1337 \
torchrun --standalone --nproc_per_node=8 train_gpt.py 2>&1 | tee exp074_results.log

echo "=== Done! Check exp074_results.log for val_bpb ==="
grep -E "val_bpb|DIAGNOSTIC|prequant_ttt|sliding" exp074_results.log | tail -20
