Commit ca19512 — Add official-template-safe 11L XSA+EMA submission
1 parent fc6332a

6 files changed: 3443 additions & 0 deletions

### README.md (1 addition)
@@ -30,6 +30,7 @@ Happy training!

| Run | Score | Author | Summary | Date | Info |
|-----|------:|--------|---------|------|------|
| 11L XSA4 + EMA + Batch524K | 1.1357 | dennisimoo | 11 layers, XSA on last 4 layers, EMA, batch 524288, official-template-safe zstd fallback | 2026-03-21 | [info](records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_B524K_ZstdFallback/README.md) |
| 10L Int5-MLP + BigramHash(10240) | 1.1428 | thwu1 | 10 layers, mixed int5/int6 quantization, BigramHash(10240), SWA(0.4), WD=0.04 | 2026-03-20 | [info](records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/README.md) |
| Int6 MLP3x + SmearGate + BigramHash | 1.1458 | Raahil Shah | 3x MLP + SmearGate + BigramHash + OrthoInit + Muon WD + SWA | 2026-03-20 | [info](records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/README.md) |
| 11L MLP3x + Int6 QAT | 1.1502 | aruniyer | 11 layers, 3x MLP, int6 QAT, zstd-22, WD=0.04, sliding eval | 2026-03-20 | [info](records/track_10min_16mb/2026-03-19_MLP3x_QAT_Int6_SlidingWindow/README.md) |
### records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_B524K_ZstdFallback/README.md (57 additions)

# Record: 11L XSA4 + EMA + Batch524K + zstd Fallback

**val_bpb = 1.1357** (sliding window, stride=64) | **15.67 MB** artifact | 8xH100 SXM, ~600s

Single-seed submission: an 11-layer int6 MLP3x model with XSA on the last 4 layers, EMA weight averaging, SmearGate, BigramHash, and a fixed-time batch size of 524288 tokens.
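EMA weight averaging keeps a shadow copy of the weights that is blended toward the live weights each step and used for evaluation. A minimal torch-free sketch, assuming the run's decay of 0.997 (names illustrative, not the training script's actual code):

```python
class Ema:
    """Exponential moving average of a flat parameter vector.

    Illustrative sketch with decay=0.997 as in this run; the real training
    script operates on model tensors, not Python lists of floats.
    """

    def __init__(self, params, decay=0.997):
        self.decay = decay
        self.shadow = list(params)  # averaged copy, read at eval time

    def update(self, params):
        d = self.decay
        self.shadow = [d * s + (1.0 - d) * p for s, p in zip(self.shadow, params)]


ema = Ema([0.0])
for _ in range(1000):
    ema.update([1.0])
# after 1000 updates toward 1.0, the shadow has closed ~95% of the gap
```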
## Result

| Metric | Value |
|--------|-------|
| Pre-quant val_bpb | 1.1529 |
| Int6 roundtrip val_bpb | 1.1580 |
| **Int6 sliding val_bpb (stride 64)** | **1.1357** |
| Model bytes (int6+zstd) | 15,603,062 |
| Code bytes | 66,891 |
| **Total submission bytes** | **15,669,953** |

This is below (i.e. better than) the current merged README SOTA (`1.1428`), but it is not a 3-seed validated record claim.
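The stride-64 sliding eval scores each token with near-full left context: after the first window, each 2048-token window contributes only its last 64 tokens to the loss. A hedged sketch of the window bookkeeping (names illustrative, not the eval code itself):

```python
def sliding_eval_spans(n_tokens, seq_len=2048, stride=64):
    """Return (ctx_begin, ctx_end, n_scored) tuples covering all n_tokens.

    Only the last n_scored tokens of each window contribute to the loss;
    earlier tokens in the window serve as context only.
    """
    spans, prev_end = [], 0
    for begin in range(0, n_tokens, stride):
        end = min(begin + seq_len, n_tokens)
        spans.append((begin, end, end - prev_end))
        prev_end = end
        if end == n_tokens:
            break
    return spans


spans = sliding_eval_spans(4096)
# every token is scored exactly once
assert sum(n for _, _, n in spans) == 4096
# after the first window, each window newly scores `stride` tokens
assert spans[1] == (64, 2112, 64)
```

The trade-off is cost: stride 64 runs ~32x more forward passes per token than non-overlapping 2048-token windows, in exchange for a lower (more context-faithful) bpb.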
## What's New

| Change | Impact |
|--------|--------|
| `TRAIN_BATCH_TOKENS=524288` | More optimizer steps within the fixed time budget than the larger-batch 11-layer XSA+EMA setting |
| SDPA fallback for `flash_attn_interface` | Runs cleanly when FA3 Python bindings are unavailable in the official image |
| `torch.compile` behind an env flag | Reliable eager smoke tests, faster compiled full run |
| `zstd` Python-or-CLI fallback | Keeps the int6 export under 16 MB without depending on a specific Python package in the image |
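The FA3-to-SDPA fallback can be sketched as a guarded import. This is a sketch only: `flash_attn_func`'s exact signature and return type vary across FlashAttention versions, and the function names below are assumptions, not the script's actual code.

```python
import torch
import torch.nn.functional as F

try:
    # FlashAttention-3 Python bindings (Hopper build); may be absent
    from flash_attn_interface import flash_attn_func
    HAVE_FA3 = True
except ImportError:
    HAVE_FA3 = False


def attend(q, k, v):
    """Causal attention over q, k, v of shape (batch, heads, seq, head_dim)."""
    if HAVE_FA3:
        # FA3 expects (batch, seq, heads, head_dim); some versions return
        # (out, lse) rather than just out
        out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2),
                              v.transpose(1, 2), causal=True)
        if isinstance(out, tuple):
            out = out[0]
        return out.transpose(1, 2)
    # SDPA fallback: PyTorch dispatches to a fused kernel when one is available
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

Because both paths take and return the same layout, the rest of the model is unchanged whichever branch runs.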
## Configuration

```bash
NUM_LAYERS=11 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=3 \
TRAIN_BATCH_TOKENS=524288 TRAIN_SEQ_LEN=2048 EVAL_SEQ_LEN=2048 EVAL_STRIDE=64 \
BIGRAM_VOCAB_SIZE=2048 BIGRAM_DIM=128 \
XSA_LAST_N=4 EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=0 TTT_ENABLED=0 \
MUON_WD=0.04 ADAM_WD=0.04 MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \
MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 MUON_MOMENTUM_WARMUP_STEPS=1500 \
WARMDOWN_ITERS=3000 WARMUP_STEPS=20 ENABLE_TORCH_COMPILE=1 \
MAX_WALLCLOCK_SECONDS=600 torchrun --standalone --nproc_per_node=8 train_gpt.py
```
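The `zstd` Python-or-CLI fallback can be sketched as follows. The function name and default level are illustrative, not the script's actual export code; the real run targets the int6 weight blob.

```python
import shutil
import subprocess


def zstd_compress(data: bytes, level: int = 19) -> bytes:
    """Compress with the `zstandard` package if importable, else the zstd CLI."""
    try:
        import zstandard  # Python bindings, preferred when installed
    except ImportError:
        zstandard = None
    if zstandard is not None:
        return zstandard.ZstdCompressor(level=level).compress(data)
    if shutil.which("zstd"):  # CLI fallback present in many base images
        proc = subprocess.run(["zstd", f"-{level}", "--stdout"],
                              input=data, stdout=subprocess.PIPE, check=True)
        return proc.stdout
    raise RuntimeError("neither the `zstandard` package nor the zstd CLI is available")
```

Raising only when both paths are missing is what makes the export "official-template-safe": it works with whichever zstd the image happens to ship.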
## Key Run Details

| Metric | Value |
|--------|-------|
| Steps reached | 8,202 |
| Average train step time | 73.37 ms |
| Peak memory allocated | 13,879 MiB |
| Peak memory reserved | 14,004 MiB |
| Final eval mode | Sliding window, stride 64 |
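These numbers are mutually consistent, as a quick back-of-envelope check shows (assuming `TRAIN_BATCH_TOKENS` counts global tokens per optimizer step):

```python
steps = 8_202
batch_tokens = 524_288
step_time_s = 0.07337

total_time_s = steps * step_time_s         # ~601.8 s, matching the 600 s budget
total_tokens = steps * batch_tokens        # ~4.30B tokens trained
tokens_per_s = batch_tokens / step_time_s  # ~7.1M tokens/s across 8xH100

assert 595 < total_time_s < 610
assert round(total_tokens / 1e9, 2) == 4.3
assert 7.0e6 < tokens_per_s < 7.3e6
```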
## Included Files

- `train_gpt.py` — training, export, and eval script
- `run_hybrid_attempt.sh` — launch wrapper used for the run
- `train.log` — full log from the validated 600 s attempt
- `submission.json` — metadata for the submission
### run_hybrid_attempt.sh (41 additions)

```bash
#!/usr/bin/env bash
set -euo pipefail

SEED="${SEED:-1337}"
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"

export RUN_ID="${RUN_ID:-hybrid_xsa_ema_safe}"
export SEED
export NUM_LAYERS="${NUM_LAYERS:-11}"
export MODEL_DIM="${MODEL_DIM:-512}"
export NUM_HEADS="${NUM_HEADS:-8}"
export NUM_KV_HEADS="${NUM_KV_HEADS:-4}"
export MLP_MULT="${MLP_MULT:-3}"
export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}"
export EVAL_SEQ_LEN="${EVAL_SEQ_LEN:-2048}"
export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}"
export EVAL_STRIDE="${EVAL_STRIDE:-64}"
export BIGRAM_VOCAB_SIZE="${BIGRAM_VOCAB_SIZE:-2048}"
export BIGRAM_DIM="${BIGRAM_DIM:-128}"
export MATRIX_LR="${MATRIX_LR:-0.025}"
export SCALAR_LR="${SCALAR_LR:-0.025}"
export TIED_EMBED_LR="${TIED_EMBED_LR:-0.035}"
export MUON_MOMENTUM="${MUON_MOMENTUM:-0.99}"
export MUON_MOMENTUM_WARMUP_START="${MUON_MOMENTUM_WARMUP_START:-0.92}"
export MUON_MOMENTUM_WARMUP_STEPS="${MUON_MOMENTUM_WARMUP_STEPS:-1500}"
export MUON_WD="${MUON_WD:-0.04}"
export ADAM_WD="${ADAM_WD:-0.04}"
export WARMDOWN_ITERS="${WARMDOWN_ITERS:-3000}"
export WARMUP_STEPS="${WARMUP_STEPS:-30}"
export ITERATIONS="${ITERATIONS:-9000}"
export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}"
export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-0}"
export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-200}"
export XSA_LAST_N="${XSA_LAST_N:-4}"
export EMA_ENABLED="${EMA_ENABLED:-1}"
export EMA_DECAY="${EMA_DECAY:-0.997}"
export SWA_ENABLED="${SWA_ENABLED:-0}"
export TTT_ENABLED="${TTT_ENABLED:-0}"
export ENABLE_TORCH_COMPILE="${ENABLE_TORCH_COMPILE:-0}"

torchrun --standalone --nproc_per_node="${NPROC_PER_NODE}" train_gpt.py
```
### submission.json (16 additions)

```json
{
  "author": "dennisimoo",
  "github_id": "dennisimoo",
  "name": "Record: 11L XSA4 + EMA + Batch524K + zstd fallback",
  "blurb": "11-layer int6 MLP3x model with SmearGate, BigramHash(2048x128), XSA on the last 4 layers, EMA(0.997), WD=0.04, batch 524288, sliding-window eval stride 64, SDPA fallback when FA3 is unavailable, and official-template-safe compression via Python zstandard or zstd CLI fallback.",
  "date": "2026-03-21T00:00:00Z",
  "val_loss": 1.91760887,
  "val_bpb": 1.13571899,
  "pre_quant_val_loss": 1.9465,
  "pre_quant_val_bpb": 1.1529,
  "int6_zstd_val_loss": 1.95527145,
  "int6_zstd_val_bpb": 1.15802189,
  "bytes_total": 15669953,
  "bytes_model_int6_zstd": 15603062,
  "bytes_code": 66891
}
```
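The byte counts above are internally consistent and leave headroom under the track's 16 MB cap (assuming the cap is 16 × 1024 × 1024 bytes, as the track name suggests):

```python
bytes_model = 15_603_062  # int6 weights after zstd
bytes_code = 66_891       # submitted source files
bytes_total = bytes_model + bytes_code

assert bytes_total == 15_669_953       # matches bytes_total in submission.json
assert bytes_total < 16 * 1024 * 1024  # 16,777,216-byte cap, ~1.06 MiB headroom
```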
