Commit f7d9e92

Non-Record: 11L Parallel Muon + LN Scale + LeakyReLU² MLP3x + Legal TTT (val_bpb 1.1215, 3-seed)
1 parent 8e62acb commit f7d9e92

6 files changed

Lines changed: 753 additions & 761 deletions

Lines changed: 18 additions & 16 deletions
@@ -1,21 +1,22 @@
-# Record: 11L Parallel Muon + LeakyReLU² MLP3x + Legal Score-First TTT
+# Non-Record: 11L Parallel Muon + LN Scale + LeakyReLU² MLP3x + Legal TTT
 
-**3-seed mean val_bpb: 1.1253** (std=0.0002) | **~15 MB** | 8xH100 SXM
+**3-seed mean val_bpb: 1.1215** (std=0.0002) | **~15.85 MB** | 8xH100 SXM
 
 ## 3-Seed Results (8xH100 80GB SXM, PyTorch 2.9.1+cu128)
 
 | Seed | step_avg | steps | EMA bpb | Quantized bpb | **TTT bpb** |
 |------|----------|-------|---------|---------------|-------------|
-| 1337 | 91.5ms | 6,556 | 1.1194 | 1.1291 | **1.1255** |
-| 42 | 89.2ms | 6,726 | 1.1195 | 1.1278 | **1.1253** |
-| 2024 | 89.3ms | 6,722 | 1.1193 | 1.1280 | **1.1251** |
-| **Mean** | **90.0ms** | **6,668** | **1.1194** | **1.1283** | **1.1253** |
+| 1337 | 88.8ms | 6,759 | 1.1161 | 1.1238 | **1.1217** |
+| 42 | 88.8ms | 6,757 | 1.1158 | 1.1234 | **1.1213** |
+| 2024 | 88.9ms | 6,752 | 1.1160 | 1.1234 | **1.1215** |
+| **Mean** | **88.8ms** | **6,756** | **1.1160** | **1.1235** | **1.1215** |
 
-## Architecture (29.8M parameters)
+## Architecture (26.8M parameters)
 
 - 11 transformer layers, dim=512, 8 heads / 4 KV heads (GQA)
 - **Parallel Muon** with parameter banking (4 contiguous 3D banks) + batched Newton-Schulz
 - MLP 3x expansion (hidden=1536) with **LeakyReLU(0.5)²** activation
+- **LN Scale** — depth-dependent normalization: 1/sqrt(layer_idx+1)
 - **SmearGate** + **BigramHash(1536, dim=128)**
 - **Value Residual (ResFormer)** — cache V from layer 0, blend via learned lambda
 - **Gated Attention** — per-head sigmoid gate (nn.Linear, bias init 4.0)
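The LN Scale line added above multiplies each block's normalized output by a depth-dependent factor. A minimal stdlib sketch (the `rms_norm_scaled` helper is hypothetical; the actual model applies the factor inside its PyTorch norm layers):

```python
import math

def ln_scale(layer_idx: int) -> float:
    # Depth-dependent scale from the architecture list: 1/sqrt(layer_idx + 1),
    # so deeper layers contribute progressively smaller post-norm outputs.
    return 1.0 / math.sqrt(layer_idx + 1)

def rms_norm_scaled(x: list[float], layer_idx: int, eps: float = 1e-6) -> list[float]:
    # Hypothetical helper: RMSNorm followed by the depth-dependent scale.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * ln_scale(layer_idx) for v in x]

print(round(ln_scale(0), 3))   # layer 0 keeps scale 1.0
print(round(ln_scale(10), 3))  # the 11th layer is scaled by ~0.302
```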
@@ -26,39 +27,40 @@
 
 ## Training
 
-- **Parallel Muon optimizer**: 3-phase async reduce-scatter → Adam → NS5+all-gather
-- lr=0.025, momentum 0.92→0.99/1500 steps, WD=0.04
-- No DDP — manual gradient sync for non-bank params
+- **Parallel Muon optimizer**: 3-phase async reduce-scatter -> Adam -> NS5+all-gather
+- lr=0.025, momentum 0.92->0.99/1500 steps, WD=0.04
+- No DDP -- manual gradient sync for non-bank params
 - Adam for embeddings (lr=0.035) and scalars (lr=0.025)
 - Batch 786,432 tokens, seq_len 2048
 - EMA (decay=0.997) + SWA (every 50 steps when scale < 0.2)
 - Warmdown 3500 iterations (wallclock-based)
-- Late QAT via STE (final 15% of wallclock), symmetric [-31, 31] range
+- Late QAT via STE (final 15% of wallclock)
 - Gradient clipping 0.3
-- torch.compile(fullgraph=True) — no DDP wrapper for maximum compilation
+- torch.compile(fullgraph=True)
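The momentum ramp in the training list (0.92 to 0.99 over the first 1500 steps) can be written as a schedule function. This is a sketch: the linear interpolation shape is an assumption, and the actual train_gpt.py ramp may differ:

```python
def muon_momentum(step: int, start: float = 0.92, end: float = 0.99, warmup: int = 1500) -> float:
    # Linear warmup of Muon momentum over the first `warmup` steps,
    # then held constant at `end` (linear ramp is assumed).
    frac = min(step / warmup, 1.0)
    return start + (end - start) * frac

print(round(muon_momentum(0), 3))     # 0.92
print(round(muon_momentum(750), 3))   # 0.955
print(round(muon_momentum(9999), 3))  # 0.99
```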
 
 ## Quantization
 
 - Int6 uniform per-row with GPTQ-lite (5-percentile clip search per row)
 - FP16 passthrough for tied embeddings
 - zstd-22 compression
-- Unbank → quantize → rebank for compatibility with parameter banking
+- Unbank -> quantize -> rebank for compatibility with parameter banking
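The per-row int6 scheme above can be sketched as follows. The candidate percentiles are assumptions (the README only says "5-percentile clip search"), and the real GPTQ-lite pass operates on GPU tensors rather than Python lists:

```python
def quantize_row_int6(row, percentiles=(1.0, 0.995, 0.99, 0.985, 0.98)):
    # Symmetric uniform int6: codes in [-31, 31], one scale per row.
    # For each candidate clip (a percentile of |w|), quantize and keep
    # the clip with the lowest squared reconstruction error.
    QMAX = 31
    abs_sorted = sorted(abs(v) for v in row)
    best = None
    for p in percentiles:
        clip = abs_sorted[int(p * (len(row) - 1))] or 1e-12  # avoid scale 0
        scale = clip / QMAX
        q = [max(-QMAX, min(QMAX, round(v / scale))) for v in row]
        err = sum((v - c * scale) ** 2 for v, c in zip(row, q))
        if best is None or err < best[0]:
            best = (err, q, scale)
    _, q, scale = best
    return q, scale

q, s = quantize_row_int6([1.0, -1.0, 0.5, 0.0])
assert all(-31 <= c <= 31 for c in q) and q[-1] == 0
```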
 
 ## Legal Score-First TTT (PR #461 / #549 recipe)
 
 Every token scored BEFORE any weight update:
 
 ```
 for each 32K-token chunk:
-  Phase 1 — SCORE: sliding window eval (inference_mode, stride=64)
-  Phase 2 — TRAIN: SGD(lr=0.002, momentum=0.9), 3 epochs, all blocks unfrozen, cosine LR
+  Phase 1 -- SCORE: sliding window eval (inference_mode, stride=64)
+  Phase 2 -- TRAIN: SGD(lr=0.002, momentum=0.9), 3 epochs, all blocks unfrozen, cosine LR
 ```
 
-TTT improves quantized BPB by ~0.003 (1.1283 → 1.1253).
+TTT improves quantized BPB by ~0.002 (1.1235 -> 1.1215).
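The legality invariant, every token scored before any update that has seen it, can be illustrated with a toy driver; `score_fn` and `train_fn` are stand-ins for the real sliding-window eval and SGD phases:

```python
def legal_ttt(chunks, score_fn, train_fn):
    # Score-first TTT: each chunk is evaluated with the CURRENT weights
    # (Phase 1) before the model is allowed to adapt on it (Phase 2).
    losses = []
    for chunk in chunks:
        losses.append(score_fn(chunk))  # Phase 1: eval only, no updates
        train_fn(chunk)                 # Phase 2: adapt on the scored chunk
    return sum(losses) / len(losses)

# Toy check: training lowers a weight, but each score uses the pre-update value.
state = {"w": 1.0}
mean = legal_ttt(
    [1.0, 2.0, 3.0],
    score_fn=lambda c: state["w"] * c,
    train_fn=lambda c: state.update(w=state["w"] - 0.1),
)
print(round(mean, 4))  # mean of 1.0*1, 0.9*2, 0.8*3
```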
 
 ## Credits
 
 - Parallel Muon / Parameter Banking: PR #399 by @abaybektursun
 - LeakyReLU²: PR #493 by @parinzee, PR #518 by @sofiabod
+- LN Scale: PR #315/374 by @jfprincz
 - TTT recipe: PR #461 by @Christopher-Lee-McClendon (adapted: freeze=0)
 - Base model stack: PR #414 by @signalrush
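For reference, the LeakyReLU(0.5)² activation credited above, sketched on scalars. Whether the implementation restores the sign for negative inputs is a detail of the cited PRs; this sketch squares the leaky output directly:

```python
def leaky_relu_sq(x: float, slope: float = 0.5) -> float:
    # LeakyReLU(0.5) followed by squaring: x**2 for x >= 0,
    # (0.5*x)**2 = 0.25*x**2 for x < 0.
    y = x if x >= 0 else slope * x
    return y * y

print(leaky_relu_sq(2.0))   # 4.0
print(leaky_relu_sq(-2.0))  # 1.0
```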
Lines changed: 10 additions & 10 deletions
@@ -1,16 +1,16 @@
 {
   "author": "Aryan Bhosale",
   "github_id": "aryanbhosale",
-  "name": "11L Parallel Muon + LeakyReLU² MLP3x + Legal Score-First TTT (mean val_bpb=1.1253)",
-  "blurb": "11-layer 512d transformer with Parallel Muon (parameter banking + batched NS5), LeakyReLU(0.5)² MLP 3x, SmearGate, BigramHash(1536), Value Residual, Gated Attention, XSA4, Partial RoPE(16/64), EMA(0.997)+SWA, Late QAT, GPTQ-lite int6+zstd-22, legal score-first TTT (SGD momentum=0.9, lr=0.002, 3 epochs). 3-seed mean 1.1253 BPB (std=0.0002) on 8xH100 SXM.",
-  "date": "2026-03-25T12:00:00Z",
-  "val_loss": 1.9000,
-  "val_bpb": 1.1253,
-  "bytes_total": 15000000,
-  "bytes_code": 78438,
+  "name": "11L Parallel Muon + LN Scale + LeakyReLU² MLP3x + Legal TTT (mean val_bpb=1.1215)",
+  "blurb": "11-layer 512d transformer with Parallel Muon (parameter banking + batched NS5), LN Scale, LeakyReLU(0.5)² MLP 3x, SmearGate, BigramHash(1536), Value Residual, Gated Attention, XSA4, Partial RoPE(16/64), EMA(0.997)+SWA, Late QAT, GPTQ-lite int6+zstd-22, legal score-first TTT. 3-seed mean 1.1215 BPB (std=0.0002) on 8xH100 SXM.",
+  "date": "2026-03-26T12:00:00Z",
+  "val_loss": 1.8937,
+  "val_bpb": 1.1215,
+  "bytes_total": 15850000,
+  "bytes_code": 80000,
   "seeds": {
-    "1337": {"val_bpb": 1.1255, "val_loss": 1.9004, "steps": 6556, "step_avg_ms": 91.5},
-    "42": {"val_bpb": 1.1253, "val_loss": 1.8999, "steps": 6726, "step_avg_ms": 89.2},
-    "2024": {"val_bpb": 1.1251, "val_loss": 1.8997, "steps": 6722, "step_avg_ms": 89.3}
+    "1337": {"val_bpb": 1.1217, "val_loss": 1.8940, "steps": 6759, "step_avg_ms": 88.8},
+    "42": {"val_bpb": 1.1213, "val_loss": 1.8933, "steps": 6757, "step_avg_ms": 88.8},
+    "2024": {"val_bpb": 1.1215, "val_loss": 1.8937, "steps": 6752, "step_avg_ms": 88.9}
   }
 }

records/track_10min_16mb/2026-03-25_11L_ParallelMuon_MLP3x_TTT/train_gpt.py

Lines changed: 8 additions & 18 deletions
@@ -1,9 +1,8 @@
-"""SOTA config: Parallel Muon + parameter banks + MLP 3x + 11L XSA + LN Scale + VE128 + GPTQ-lite int6 + LZMA."""
+"""SOTA config: Parallel Muon + parameter banks + MLP 3x + 11L XSA + LN Scale + GPTQ-lite int6 + zstd-22."""
 from __future__ import annotations
 import copy
 import glob
 import io
-import lzma
 import math
 import os
 import random
@@ -16,20 +15,11 @@
 
 try:
     import zstandard
+    def _compress(data: bytes) -> bytes: return zstandard.ZstdCompressor(level=22).compress(data)
+    def _decompress(data: bytes) -> bytes: return zstandard.ZstdDecompressor().decompress(data)
 except ImportError:
-    zstandard = None
-
-    def _compress(data: bytes) -> bytes:
-        return lzma.compress(data, preset=6)
-
-    def _decompress(data: bytes) -> bytes:
-        try:
-            return lzma.decompress(data)
-        except Exception:
-            try:
-                return zstandard.ZstdDecompressor().decompress(data)
-            except Exception:
-                return zlib.decompress(data)
+    def _compress(data: bytes) -> bytes: return zlib.compress(data, level=9)
+    def _decompress(data: bytes) -> bytes: return zlib.decompress(data)
 
 import numpy as np
 import sentencepiece as spm
@@ -95,11 +85,11 @@ class Hyperparameters:
 
     bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 1536))
     bigram_dim = int(os.environ.get("BIGRAM_DIM", 128))
-    xsa_last_n = int(os.environ.get("XSA_LAST_N", 11))
+    xsa_last_n = int(os.environ.get("XSA_LAST_N", 4))
     rope_dims = int(os.environ.get("ROPE_DIMS", 16))
    ln_scale = bool(int(os.environ.get("LN_SCALE", "1")))
-    ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1")))
-    ve_dim = int(os.environ.get("VE_DIM", 64))
+    ve_enabled = bool(int(os.environ.get("VE_ENABLED", "0")))
+    ve_dim = int(os.environ.get("VE_DIM", 32))
     ve_layers = os.environ.get("VE_LAYERS", "9,10")
 
     use_smeargate = bool(int(os.environ.get("USE_SMEARGATE", "1")))