
Commit c8e034d

Record: Cosine TTT scheduling with per-layer lr (mean val_bpb=1.0970, 3 seeds)
AdamW TTT with cosine lr decay over 30 epochs and per-layer lr groups (3x for MLP output projections, 0.5x for input projections). 34 TTT configurations tested. FINDINGS.md documents 31 experiments including negative results on codebook quantization, symmetry-transport, layer dropping, focal loss, and KL divergence TTT. Builds on PRs openai#162, openai#180, openai#77, openai#398, openai#442, openai#417, openai#315.
# Cosine TTT Scheduling with Per-Layer Learning Rates

Mean val_bpb = 1.0970 (3 seeds, std=0.0010) | 8×H100 SXM | 600s train + 465s TTT + 187s eval

## Results

| Seed | Steps | Pre-TTT val_bpb | Post-TTT val_bpb | Artifact size |
|------|-------|-----------------|------------------|---------------|
| 1337 | 7,101 | 1.1577          | 1.0959           | 15.4 MB       |
| 42   | 6,700 | 1.1588          | 1.0971           | 15.5 MB       |
| 7    | 6,987 | 1.1580          | 1.0979           | 15.8 MB       |
## Background

Starting from the community stack (PRs #162, #180, #315, #398), we spent several days exploring ways to improve compression and eval-time adaptation. Many of these explorations did not improve the result, but they informed the direction that eventually worked.

### Compression research (did not improve score)

We analyzed trained checkpoints to evaluate alternative quantization and compression approaches:
- **Learned codebook quantization** (K-means, K=256): 87% lower reconstruction MSE than uniform int6, but a 25% larger compressed artifact under zstd-22. Codebook indices have higher byte entropy than clamped int6 values (see the sketch after this list).
- **Symmetry-transport** (Procrustes alignment across layers): Layers share 91-93% rotational structure, but storing the rotation matrices costs more than storing the weights directly. A low-rank approximation of the rotation delta (rank 128) captured only 16.6% of the variance.
- **Embedding low-rank factorization** (SVD): Rank 64 explains 41.9% of the variance of tok_emb (1024×512). Not viable at this vocabulary size.
- **Magnitude pruning**: Non-monotonic interaction with zstd-22; 3% pruning increased the artifact size by 728 KB on our checkpoint.

These results indicated that int6+zstd is close to optimal for this model architecture and that compression was not the path to further improvement.
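A minimal sketch of the codebook-vs-int6 comparison above, assuming flat float32 weights and the scikit-learn and zstandard packages; MiniBatchKMeans stands in for full K-means to keep it fast, and the random test matrix and helper names are illustrative, not the repository's code.

```python
import numpy as np
import zstandard as zstd
from sklearn.cluster import MiniBatchKMeans

def codebook_quantize(w: np.ndarray, k: int = 256):
    """K-means codebook quantization: uint8 indices plus a k-entry float32 codebook."""
    km = MiniBatchKMeans(n_clusters=k, random_state=0, n_init=3).fit(w.reshape(-1, 1))
    return km.labels_.astype(np.uint8), km.cluster_centers_.ravel().astype(np.float32)

def int6_quantize_per_row(w: np.ndarray):
    """Per-row symmetric int6 quantization (values clamped to [-31, 31])."""
    scale = np.maximum(np.abs(w).max(axis=1, keepdims=True) / 31.0, 1e-12)
    return np.clip(np.round(w / scale), -31, 31).astype(np.int8), scale

def zstd22_bytes(buf: bytes) -> int:
    """Size of the buffer after zstd level-22 compression."""
    return len(zstd.ZstdCompressor(level=22).compress(buf))

# Illustrative comparison on a random matrix shaped like an MLP projection:
# the codebook reconstructs more accurately, but its index stream tends to have
# higher byte entropy than clamped int6 values, so it compresses worse under zstd-22.
w = np.random.randn(512, 1536).astype(np.float32)
idx, codebook = codebook_quantize(w)
q6, scale = int6_quantize_per_row(w)
mse_codebook = float(np.mean((codebook[idx] - w.ravel()) ** 2))
mse_int6 = float(np.mean((q6.astype(np.float32) * scale - w) ** 2))
print(f"MSE       codebook={mse_codebook:.6f}  int6={mse_int6:.6f}")
print(f"zstd-22   codebook={zstd22_bytes(idx.tobytes())} B  int6={zstd22_bytes(q6.tobytes())} B")
```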
### Architectural exploration (did not improve score)

- **Progressive layer dropping**: Randomly skipping layers during training for regularization. Caused a 0.06 BPB regression at step 1000 when combined with head dropout. The DDP implementation also introduced higher-order ops incompatible with torch.compile + DDPOptimizer.
- **Depth recurrence** (Huginn-style, 3 shared blocks × 3 loops): Blocks learned position-specific functions rather than general refiners. Evaluating at 2× the trained depth produced val_bpb 4.34. Not viable below ~100M params per unique layer.
- **Neural cache** (cross-window KV caching at eval): Implemented but not validated on hardware. The original proposal (PR #318) was blocked by a torch.compile issue.
### TTT analysis (led to the finding)

Analyzing our trained checkpoint, we observed:

1. **Quantization error is uniformly distributed** — the top 1% of weights by error magnitude account for only 3.9% of total reconstruction error. This confirmed that outlier-protection approaches would not be effective.
2. **Quantization damage varies 3.4× across layer types** — MLP output projections (512×1536) have systematically higher relative error than input projections (1536×512); see the measurement sketch after this list.
3. **TTT improvement exceeds quantization repair** — the TTT contribution (~0.06 BPB on our model) is roughly 2.4× larger than the quantization gap (~0.008), indicating that TTT performs distribution adaptation beyond repairing quantization damage.

These observations motivated exploring the TTT schedule rather than the training architecture or compression scheme.
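A hedged sketch of the kind of per-layer damage measurement behind observation 2, assuming per-row symmetric int6 quantization; the helper names are illustrative rather than the repository's code.

```python
import torch

def int6_per_row_dequant(w: torch.Tensor) -> torch.Tensor:
    """Simulate per-row symmetric int6 quantization and return the dequantized tensor."""
    scale = (w.abs().amax(dim=1, keepdim=True) / 31.0).clamp_min(1e-12)
    return torch.clamp(torch.round(w / scale), -31, 31) * scale

def relative_quant_error(model: torch.nn.Module) -> dict:
    """Relative Frobenius error introduced by int6 quantization, per 2-D parameter."""
    errs = {}
    for name, p in model.named_parameters():
        if p.ndim != 2:
            continue
        w = p.data.float()
        errs[name] = ((int6_per_row_dequant(w) - w).norm() / w.norm()).item()
    return errs

# Parameters with the highest relative error (the MLP output projections here)
# are the candidates for a larger TTT learning-rate multiplier.
```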
## TTT schedule

Two modifications to AdamW TTT (PR #442):

**Cosine lr decay** over 30 epochs instead of a flat lr over 10 epochs. Quantization introduces both large-scale damage (outlier weight rounding) and distributed noise (small perturbations across all weights). A flat lr must compromise between these two regimes. Cosine decay applies the full lr early to address the large damage, then progressively reduces it to refine without overshooting.

**Per-layer lr groups** based on the quantization damage measurements above. MLP output projections receive 3× the base lr, input projections 0.5×, and all other parameters 1×. This allocates more adaptation capacity to the more damaged layers. The ratios are specific to our model — other architectures may show different damage profiles.
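A minimal sketch of how these two modifications might be wired together with stock PyTorch; the `mlp.c_proj` / `mlp.c_fc` substrings and the dummy module are assumptions about module naming, not the repository's exact implementation.

```python
import torch

def build_ttt_optimizer(model: torch.nn.Module, base_lr: float = 5e-4,
                        proj_mult: float = 3.0, fc_mult: float = 0.5) -> torch.optim.AdamW:
    """AdamW with per-layer lr groups: 3x base lr for MLP output projections,
    0.5x for input projections, 1x for everything else."""
    proj, fc, rest = [], [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if "mlp.c_proj" in name:      # MLP output projection (assumed naming)
            proj.append(p)
        elif "mlp.c_fc" in name:      # MLP input projection (assumed naming)
            fc.append(p)
        else:
            rest.append(p)
    return torch.optim.AdamW([
        {"params": proj, "lr": base_lr * proj_mult},
        {"params": fc,   "lr": base_lr * fc_mult},
        {"params": rest, "lr": base_lr},
    ])

# Dummy module tree that mimics the assumed naming, only to make the sketch runnable.
model = torch.nn.ModuleDict({"mlp": torch.nn.ModuleDict({
    "c_fc": torch.nn.Linear(512, 1536), "c_proj": torch.nn.Linear(1536, 512)})})

optimizer = build_ttt_optimizer(model)
# CosineAnnealingLR decays each group's lr from its own base value toward zero
# over the 30 TTT epochs, so the per-layer ratios are preserved throughout.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

for epoch in range(30):
    # ... one TTT pass over the validation shard goes here ...
    scheduler.step()
```

With this setup, `scheduler.get_last_lr()` reports one value per group, each following the same cosine curve scaled by its own multiplier.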
We tested 34 TTT configurations across optimizers (AdamW, Adam, SGD), learning rates (1e-4 to 2e-3), epoch counts (3 to 30), schedules (flat, cosine, warmup+cosine), per-layer groupings, freeze strategies, and loss functions (cross-entropy, focal loss γ=1-3, KL divergence from the pre-quant model).

Focal loss did not improve over cross-entropy — hard tokens appear to be unpredictable rather than undertrained. KL divergence from the pre-quant model was less effective than cross-entropy — the pre-quant and post-quant models are similar enough that the KL signal is weak relative to the cross-entropy signal from the validation data.
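For reference, hedged sketches of the two alternative TTT objectives named above (not the repository's implementations); the logit shapes and function names are illustrative.

```python
import torch
import torch.nn.functional as F

def focal_loss(logits: torch.Tensor, targets: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    """Focal loss over next-token logits: down-weights tokens the model already predicts well."""
    logp = F.log_softmax(logits, dim=-1)                            # (batch, seq, vocab)
    logp_t = logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1)     # log-prob of the true token
    return (-(1.0 - logp_t.exp()) ** gamma * logp_t).mean()

def kl_to_prequant(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    """KL(pre-quant teacher || post-quant student); batchmean sums over positions
    and vocab, then divides by the batch size."""
    return F.kl_div(F.log_softmax(student_logits, dim=-1),
                    F.log_softmax(teacher_logits, dim=-1),
                    log_target=True, reduction="batchmean")
```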
## TTT config

```
TTT_OPTIMIZER=adamw TTT_LR=0.0005 TTT_EPOCHS=30
TTT_COSINE=1 TTT_PERLAYER=1 TTT_FREEZE_BLOCKS=0
TTT_BATCH_SEQS=64   # per GPU; 512 sequences total across 8 GPUs with DDP sharding
```

Each GPU processes a contiguous 1/8 shard of the validation tokens, with gradients averaged across ranks via all_reduce (ReduceOp.AVG). 30 epochs at ~15.5s/epoch ≈ 465s total.
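A hedged sketch of that sharded TTT epoch, assuming the validation tokens are already resident on every rank and that the model returns a mean cross-entropy loss; the names and the manual gradient averaging are illustrative, not the repository's code.

```python
import torch
import torch.distributed as dist

def ttt_epoch(model, optimizer, val_tokens: torch.Tensor,
              rank: int, world_size: int, batch_seqs: int = 64, seq_len: int = 1024):
    """One TTT epoch: each rank trains on its contiguous 1/world_size shard of the
    validation tokens, averaging gradients across ranks with all_reduce (ReduceOp.AVG)."""
    shard = val_tokens.chunk(world_size)[rank]            # contiguous shard for this rank
    usable = (shard.numel() // seq_len) * seq_len
    seqs = shard[:usable].view(-1, seq_len)
    for start in range(0, seqs.size(0), batch_seqs):
        batch = seqs[start:start + batch_seqs]
        inputs, targets = batch[:, :-1], batch[:, 1:]
        loss = model(inputs, targets)                     # assumed: returns mean CE loss
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        for p in model.parameters():
            if p.grad is not None:
                dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)   # keep ranks in lockstep
        optimizer.step()
```

Wrapping the model in DDP would perform the same gradient averaging automatically; the explicit all_reduce here just mirrors the description above.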
## Training config

Standard community stack. 11L, 512d, 8H/4KV (GQA), 3x MLP (relu-squared), U-Net skips, SmearGate, BigramHash(2048), OrthoInit, Partial RoPE (16/64 dims), LN Scale, EMA(0.997), tied embeddings. XSA disabled. Int6 per-row + zstd-22.
## Notes

- All runs used FA2. FA3 Hopper would improve pre-TTT quality through faster training steps. The TTT schedule is independent of the attention kernel.
- The cosine + per-layer schedule adds no artifact cost and minimal code complexity over flat-lr TTT.
- See PR #212 for a non-record submission documenting 25+ additional experiments.
## Reproduction

```bash
git clone https://github.com/mrdavtan/parameter-golf.git
cd parameter-golf && git checkout next-gen
pip install flash-attn --no-cache-dir --no-build-isolation
pip install zstandard sentencepiece huggingface_hub
python3 data/cached_challenge_fineweb.py --variant sp1024
bash run_competition.sh 1337
```

Hardware: 8×H100 SXM (RunPod), PyTorch 2.9.1+cu128, Flash Attention 2

Builds on PRs #162, #180, #77, #398, #442, #417, #315, and modded-nanogpt.
{
  "author": "mrdavtan",
  "github_id": "mrdavtan",
  "name": "Cosine TTT scheduling with per-layer lr (mean val_bpb=1.0970, 3 seeds)",
  "blurb": "AdamW TTT with cosine lr decay and per-layer lr groups. 30 epochs, 3x lr for MLP output projections, 0.5x for input projections.",
  "date": "2026-03-22",
  "val_loss": 1.8504,
  "val_bpb": 1.0959,
  "mean_val_bpb": 1.0970,
  "std_val_bpb": 0.0010,
  "seed": 1337,
  "num_seeds": 3,
  "seed_results": {
    "1337": 1.0959,
    "42": 1.0971,
    "7": 1.0979
  },
  "step_stop": 7101,
  "wallclock_seconds": 600.0,
  "ttt_time_seconds": 465.4,
  "eval_time_seconds": 186.5,
  "bytes_total": 15362557,
  "bytes_model_int8_zstd": 15258143,
  "bytes_code": 104414,
  "hardware": "8xH100 SXM (RunPod), PyTorch 2.9.1+cu128, FA2",
  "track": "track_10min_16mb",
  "model": {
    "num_layers": 11,
    "model_dim": 512,
    "num_heads": 8,
    "num_kv_heads": 4,
    "mlp_mult": 3,
    "vocab_size": 1024,
    "tie_embeddings": true,
    "total_params": 26829913
  },
  "ttt_config": {
    "optimizer": "adamw",
    "lr": 0.0005,
    "epochs": 30,
    "cosine": true,
    "perlayer": true,
    "perlayer_proj_mult": 3.0,
    "perlayer_fc_mult": 0.5,
    "freeze_blocks": 0,
    "batch_seqs_per_gpu": 64
  }
}
