@@ -0,0 +1,89 @@
# 11L Int6 + SmearGate + SWA + AdamW WD

**val_bpb: 1.1400** (3-seed mean, sliding window stride=64) | **15.7 MB** artifact | 8xH100 SXM, 600s

## Key Finding: Batch Size vs Step Count

The dominant factor in 10-minute training is the total number of optimization steps, not the gradient quality of each batch. Reducing the batch from 786K to 524K tokens:
- Drops step time from 91ms to 67ms (26% faster)
- Increases total steps from ~7,300 to ~8,900 (22% more)
- Sees roughly 12% fewer tokens overall, yet the extra gradient updates improve convergence

This finding applies to any fixed-time training budget and suggests the optimal batch size is smaller than commonly assumed.
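
The arithmetic behind this is simple. A minimal sketch, assuming a 600s budget and ignoring eval/checkpoint overhead, so the printed step counts are rough estimates rather than the measured ones above:

```python
def budget_tradeoff(step_time_s: float, tokens_per_step: int, budget_s: float = 600.0):
    """Approximate optimizer steps and total tokens that fit in a fixed wallclock budget."""
    steps = int(budget_s / step_time_s)
    return {"steps": steps, "tokens_seen": steps * tokens_per_step}

# The smaller batch sees fewer tokens per step, but the per-step time drops
# even more, so more gradient updates fit into the same 10-minute budget.
print(budget_tradeoff(step_time_s=0.067, tokens_per_step=524_288))
print(budget_tradeoff(step_time_s=0.091, tokens_per_step=786_432))
```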

## Technique Stack

| Component | Choice | Rationale |
|-----------|--------|-----------|
| Layers | 11 | Extra depth funded by int6 + zstd compression headroom |
| MLP | 3x (1536) | Full width; int8 tok_emb + no Late-K saves space |
| Quantization | Int6 per-row (attention + MLP), int8 (tok_emb) | Int8 tok_emb preserves output projection quality |
| SmearGate | Per-dim, 512 params | Blends adjacent token embeddings |
| BigramHash | 2048 buckets, dim=128 | Consecutive token pair features |
| Weight decay | 0.04 (Muon + AdamW) | Dual WD shrinks weights for better quantization + compression |
| SWA | ~7 checkpoints, every 200 steps | Late-training weight averaging |
| OrthoInit | gain=1.0, proj scaled 1/sqrt(2L) | Standard orthogonal initialization |
| FlashAttention | v2.8.3 | ~3% throughput improvement over PyTorch SDPA |
| Compression | zstd level 22 | 35% better than zlib for int6-in-int8 data |
| Eval | Sliding window, stride=64, batch=32 | Batched windows make stride=64 feasible in 172s |
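
A minimal sketch of the int6 per-row + zstd idea from the quantization and compression rows above (NumPy for clarity; the exact rounding, clipping range, and packing in the repo may differ). Because the quantized bytes only take 63 of 256 possible values, zstd at level 22 compresses the int6-in-int8 payload well:

```python
import numpy as np
import zstandard as zstd

def quantize_int6_per_row(w: np.ndarray):
    """Symmetric per-row quantization to 6-bit values stored in an int8 array."""
    scale = np.abs(w).max(axis=1, keepdims=True) / 31.0        # int6 symmetric range [-31, 31]
    scale = np.where(scale == 0, 1.0, scale).astype(np.float32)
    q = np.clip(np.round(w / scale), -31, 31).astype(np.int8)
    return q, scale

def dequantize(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    return q.astype(np.float32) * scale

w = (np.random.randn(1536, 512) * 0.02).astype(np.float32)     # one MLP weight, for illustration
q, scale = quantize_int6_per_row(w)
blob = zstd.ZstdCompressor(level=22).compress(q.tobytes() + scale.tobytes())
print(f"compressed: {len(blob)} bytes, max abs error: {np.abs(w - dequantize(q, scale)).max():.4f}")
```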
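SmearGate and BigramHash are both cheap input-side features: the former learns one gate per embedding dimension that mixes each token's embedding with its predecessor's, the latter embeds a hash of each consecutive token pair. A PyTorch sketch of the idea (the mixing formula, the hash, and the projection into the residual stream are assumptions, not the repo's exact code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SmearGate(nn.Module):
    """Per-dimension learned blend of each token's embedding with the previous token's."""
    def __init__(self, dim: int = 512):
        super().__init__()
        self.gate = nn.Parameter(torch.zeros(dim))             # 512 params

    def forward(self, x: torch.Tensor) -> torch.Tensor:        # x: (B, T, dim)
        prev = F.pad(x, (0, 0, 1, 0))[:, :-1]                  # x shifted right by one position
        return x + torch.sigmoid(self.gate) * prev

class BigramHash(nn.Module):
    """Embeds a hash of each (previous, current) token pair into the residual stream."""
    def __init__(self, n_buckets: int = 2048, dim: int = 128, model_dim: int = 512):
        super().__init__()
        self.n_buckets = n_buckets
        self.emb = nn.Embedding(n_buckets, dim)
        self.proj = nn.Linear(dim, model_dim, bias=False)      # assumed fold-in projection

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:   # tokens: (B, T), int64
        prev = F.pad(tokens, (1, 0))[:, :-1]
        buckets = (prev * 1_000_003 + tokens) % self.n_buckets  # illustrative hash, not the repo's
        return self.proj(self.emb(buckets))
```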
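The SWA row averages roughly seven late-training checkpoints taken every 200 steps (`SWA_FRAC=0.5` starts averaging halfway through the budget). A minimal sketch of running uniform weight averaging; the class and its interface are illustrative, not the repo's:

```python
import torch

class CheckpointAverager:
    """Uniform running average of model weights, snapshotted every `every` steps."""
    def __init__(self, every: int = 200):
        self.every = every
        self.n = 0
        self.avg = None

    @torch.no_grad()
    def maybe_update(self, step: int, start_step: int, model: torch.nn.Module):
        if step < start_step or step % self.every != 0:
            return
        state = {k: v.detach().float().clone() for k, v in model.state_dict().items()}
        if self.avg is None:
            self.avg = state
        else:
            for k, v in state.items():
                self.avg[k].lerp_(v, 1.0 / (self.n + 1))       # running mean over snapshots
        self.n += 1

    @torch.no_grad()
    def write_back(self, model: torch.nn.Module):
        """Load the averaged weights before quantizing/serializing the artifact."""
        for k, v in model.state_dict().items():
            v.copy_(self.avg[k].to(v.dtype))
```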
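The eval row refers to sliding-window scoring: each 2048-token window is scored only on its last 64 positions, so every scored token gets long left context, and windows are stacked 32 per forward pass to keep the whole sweep cheap. A rough sketch, where the model interface and the handling of the very first window are simplified assumptions:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sliding_window_nll(model, tokens, seq_len=2048, stride=64, batch_seqs=32, device="cuda"):
    """Mean next-token NLL (nats/token) with overlapping windows; each window is
    scored only on its last `stride` targets. Convert to bits/byte with
    bpb = (nll / ln 2) / (bytes per token)."""
    starts = list(range(0, tokens.numel() - seq_len - 1, stride))
    total_nll, total_targets = 0.0, 0
    for i in range(0, len(starts), batch_seqs):
        chunk = torch.stack([tokens[s : s + seq_len + 1] for s in starts[i : i + batch_seqs]])
        x, y = chunk[:, :-1].to(device), chunk[:, 1:].to(device)
        logits = model(x)                                      # (B, seq_len, vocab) assumed
        nll = F.cross_entropy(
            logits[:, -stride:].reshape(-1, logits.size(-1)),
            y[:, -stride:].reshape(-1),
            reduction="sum",
        )
        total_nll += nll.item()
        total_targets += y[:, -stride:].numel()
    return total_nll / total_targets
```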

## Metrics

| Metric | Value |
|--------|-------|
| Sliding BPB (stride=64, 3-seed mean) | **1.1400** |
| Best single seed (1338) | **1.1381** |
| Artifact size | 15.7 MB |
| Steps (600s cap) | ~8,930 |
| Step time | 67ms |
| Model params | ~26.5M |

## Reproducibility (3 seeds)

| Seed | Sliding BPB | Artifact |
|------|-------------|----------|
| 1337 | 1.1411 | 15.95 MB |
| 1338 | 1.1381 | 15.63 MB |
| 1339 | 1.1408 | 15.66 MB |
| Mean | **1.1400** | 15.7 MB |
| Std | 0.0016 | — |

## Run Command

```bash
pip install zstandard flash-attn --no-build-isolation
SEED=1338 NUM_LAYERS=11 TRAIN_SEQ_LEN=2048 TRAIN_BATCH_TOKENS=524288 \
MLP_HIDDEN=1536 BIGRAM_VOCAB_SIZE=2048 BIGRAM_DIM=128 \
MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \
MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 \
MUON_MOMENTUM_WARMUP_STEPS=1500 WARMDOWN_ITERS=3000 GRAD_CLIP_NORM=0.3 \
EVAL_SEQ_LEN=2048 EVAL_STRIDE=64 EVAL_BATCH_SEQS=32 \
MUON_WD=0.04 ADAM_WD=0.04 SWA_FRAC=0.5 SWA_EVERY=200 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

## Ablation Path (90+ experiments)

| Change | BPB | Delta |
|--------|-----|-------|
| Baseline (stock 9L) | 1.2244 | — |
| + int6 + MLP 3x + train@2048 + clip=0.3 (PR #114) | 1.1574 | -0.067 |
| + OrthoInit + MuonWD=0.02 | 1.1536 | -0.004 |
| + SmearGate + BigramHash + 10L | 1.1465 | -0.007 |
| + batch=524K (from 786K) | 1.1465 | +0.000 (same but more steps) |
| + 11L/1408, WD=0.039, FA | 1.1423 | -0.004 |
| + MLP=1536, LR=0.025, AdamW WD=0.04, int8 tok_emb | **1.1400** | **-0.002** |

## Dead Ends (selected from 90+ experiments)

- **QAT (int6 STE)**: step time rises to 115ms (vs the 67ms baseline). Better quantization quality but 25% fewer steps. Net loss. (A sketch of the STE idea follows this list.)
- **Int5 for MLP**: Saves artifact space but costs 0.020 BPB. Int6-all with tighter compression is better.
- **Batch=786K**: More tokens/step but fewer steps. 524K is optimal.
- **NorMuon**: 110ms/step. Throughput death.
- **MTP**: 86ms/step. Aux head too expensive.
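
The QAT dead end above refers to straight-through-estimator (STE) fake quantization: the forward pass sees int6-rounded weights while gradients flow through as if no rounding happened. A generic sketch, not the exact code that was benchmarked; the extra round/clamp/rescale of every weight matrix on every step is where the per-step cost comes from:

```python
import torch

class FakeQuantInt6(torch.autograd.Function):
    """Straight-through estimator: forward uses per-row int6-rounded weights,
    backward passes the gradient through the rounding unchanged."""
    @staticmethod
    def forward(ctx, w):
        scale = w.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) / 31.0
        return (w / scale).round().clamp(-31, 31) * scale

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output          # ignore the non-differentiable rounding

def qat_linear(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Every forward pays for an extra quantize/dequantize of the full weight,
    # which is the per-step overhead that made this a net loss at a 600s budget.
    return x @ FakeQuantInt6.apply(weight).t()
```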

## Previous Submissions

- PR #61: 1.2154 (warmdown-quantization discovery)
- PR #96: 1.1764 (sliding window + long-context training)
- PR #114: 1.1574 (int6 + MLP 3x + selective precision)
@@ -0,0 +1,10 @@
{
"name": "11L Int6 + SmearGate + SWA + AdamW WD (val_bpb=1.1400)",
"val_loss": 1.9267,
"val_bpb": 1.1400,
"bytes_total": 15951384,
"blurb": "11 layers with 3x MLP (1536), int6 per-row for attention+MLP, int8 for tied embedding. SmearGate + BigramHash(2048). Orthogonal init, Muon+AdamW WD=0.04, SWA ~7 checkpoint average. FlashAttention 2.8.3. Sliding window eval stride=64 with batched windows. Key finding: smaller batch (524K vs 786K) gives 40% more optimization steps at lower per-step cost, beating larger batches for total convergence.",
"author": "Sam Larson",
"github_id": "saml212",
"date": "2026-03-20"
}
110 changes: 110 additions & 0 deletions records/track_10min_16mb/2026-03-20_11L_Int6_SmearGate_SWA/train.log
@@ -0,0 +1,110 @@
W0321 03:52:08.055000 411024 torch/distributed/run.py:803]
W0321 03:52:08.055000 411024 torch/distributed/run.py:803] *****************************************
W0321 03:52:08.055000 411024 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0321 03:52:08.055000 411024 torch/distributed/run.py:803] *****************************************
logs/2f87675d-0c37-4a2b-b135-ef2f522b65b8.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:26829913
world_size:8 grad_accum_steps:1
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1338
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9292 val_bpb:4.1038 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:6.9299 train_time:129ms step_avg:128.97ms
step:2/20000 train_loss:8.5876 train_time:199ms step_avg:99.40ms
step:3/20000 train_loss:7.7827 train_time:272ms step_avg:90.83ms
step:4/20000 train_loss:7.3472 train_time:346ms step_avg:86.55ms
step:5/20000 train_loss:6.8994 train_time:420ms step_avg:84.00ms
step:6/20000 train_loss:7.7853 train_time:494ms step_avg:82.40ms
step:7/20000 train_loss:6.9407 train_time:569ms step_avg:81.30ms
step:8/20000 train_loss:6.7671 train_time:643ms step_avg:80.37ms
step:9/20000 train_loss:6.5937 train_time:717ms step_avg:79.68ms
step:10/20000 train_loss:6.3448 train_time:791ms step_avg:79.13ms
step:200/20000 train_loss:2.7481 train_time:13826ms step_avg:69.13ms
step:400/20000 train_loss:2.2446 train_time:27583ms step_avg:68.96ms
step:600/20000 train_loss:2.4573 train_time:41385ms step_avg:68.97ms
step:800/20000 train_loss:2.2084 train_time:55206ms step_avg:69.01ms
step:1000/20000 train_loss:2.3072 train_time:69015ms step_avg:69.02ms
step:1000/20000 val_loss:2.2599 val_bpb:1.3384 train_time:69027ms step_avg:69.03ms
step:1200/20000 train_loss:2.3312 train_time:82818ms step_avg:69.01ms
step:1400/20000 train_loss:2.3683 train_time:96606ms step_avg:69.00ms
step:1600/20000 train_loss:2.0413 train_time:110400ms step_avg:69.00ms
step:1800/20000 train_loss:2.1449 train_time:124217ms step_avg:69.01ms
step:2000/20000 train_loss:2.1806 train_time:138066ms step_avg:69.03ms
step:2000/20000 val_loss:2.1652 val_bpb:1.2824 train_time:138078ms step_avg:69.04ms
step:2200/20000 train_loss:2.0035 train_time:151914ms step_avg:69.05ms
step:2400/20000 train_loss:2.1274 train_time:165770ms step_avg:69.07ms
step:2600/20000 train_loss:2.3560 train_time:179592ms step_avg:69.07ms
step:2800/20000 train_loss:2.1698 train_time:193422ms step_avg:69.08ms
step:3000/20000 train_loss:2.1621 train_time:207255ms step_avg:69.08ms
step:3000/20000 val_loss:2.1243 val_bpb:1.2581 train_time:207269ms step_avg:69.09ms
step:3200/20000 train_loss:2.1217 train_time:221087ms step_avg:69.09ms
step:3400/20000 train_loss:2.0934 train_time:234902ms step_avg:69.09ms
step:3600/20000 train_loss:2.0443 train_time:248726ms step_avg:69.09ms
step:3800/20000 train_loss:2.1490 train_time:262539ms step_avg:69.09ms
step:4000/20000 train_loss:2.1164 train_time:276368ms step_avg:69.09ms
step:4000/20000 val_loss:2.1071 val_bpb:1.2480 train_time:276381ms step_avg:69.10ms
step:4200/20000 train_loss:2.1073 train_time:290293ms step_avg:69.12ms
step:4400/20000 train_loss:2.0482 train_time:304096ms step_avg:69.11ms
step:4600/20000 train_loss:1.9118 train_time:317919ms step_avg:69.11ms
step:4800/20000 train_loss:2.2034 train_time:331744ms step_avg:69.11ms
step:5000/20000 train_loss:1.9566 train_time:345555ms step_avg:69.11ms
step:5000/20000 val_loss:2.0983 val_bpb:1.2427 train_time:345569ms step_avg:69.11ms
step:5200/20000 train_loss:2.1202 train_time:359369ms step_avg:69.11ms
step:5400/20000 train_loss:2.1386 train_time:373187ms step_avg:69.11ms
step:5600/20000 train_loss:2.1310 train_time:387000ms step_avg:69.11ms
step:5800/20000 train_loss:2.0856 train_time:400813ms step_avg:69.11ms
step:6000/20000 train_loss:2.1621 train_time:414616ms step_avg:69.10ms
step:6000/20000 val_loss:2.0853 val_bpb:1.2350 train_time:414629ms step_avg:69.10ms
step:6200/20000 train_loss:2.0259 train_time:428424ms step_avg:69.10ms
step:6400/20000 train_loss:2.0991 train_time:442210ms step_avg:69.10ms
step:6600/20000 train_loss:2.0433 train_time:455989ms step_avg:69.09ms
step:6800/20000 train_loss:2.0991 train_time:469773ms step_avg:69.08ms
step:7000/20000 train_loss:2.1395 train_time:483567ms step_avg:69.08ms
step:7000/20000 val_loss:2.0390 val_bpb:1.2076 train_time:483579ms step_avg:69.08ms
step:7200/20000 train_loss:2.0985 train_time:497364ms step_avg:69.08ms
step:7400/20000 train_loss:2.0154 train_time:511170ms step_avg:69.08ms
step:7600/20000 train_loss:1.8793 train_time:524979ms step_avg:69.08ms
step:7800/20000 train_loss:2.0212 train_time:538794ms step_avg:69.08ms
step:8000/20000 train_loss:1.9796 train_time:552599ms step_avg:69.07ms
step:8000/20000 val_loss:1.9822 val_bpb:1.1740 train_time:552612ms step_avg:69.08ms
step:8200/20000 train_loss:2.0431 train_time:566395ms step_avg:69.07ms
step:8400/20000 train_loss:1.9720 train_time:580294ms step_avg:69.08ms
step:8600/20000 train_loss:1.9745 train_time:594104ms step_avg:69.08ms
step:8685/20000 val_loss:1.9422 val_bpb:1.1503 train_time:599941ms step_avg:69.08ms
stopping_early: wallclock_cap train_time:599941ms step:8685/20000
peak memory allocated: 14328 MiB reserved: 14392 MiB
ema: loading EMA weights (decay=0.997)
Serialized model: 105789375 bytes
Code size: 61617 bytes
Total submission size: 105850992 bytes
Serialized model int8+zlib: 15399540 bytes (payload:27057508 raw_torch:27113999 payload_ratio:3.91x)
Total submission size int8+zlib: 15461157 bytes
final_int8_zlib_roundtrip val_loss:1.9517 val_bpb:1.1559 eval_time:2147ms eval_seq_len:2048
final_int8_zlib_roundtrip_exact val_loss:1.95166039 val_bpb:1.15588321
final_sliding_window val_loss:1.9145 val_bpb:1.1339 eval_time:201285ms stride:64 seq_len:2048
final_sliding_window_exact val_loss:1.91451721 val_bpb:1.13388793