# 10L Int5-MLP + SmearGate + BigramHash + Late QAT

Record candidate, ready to run. `submission.json` and `train.log` are placeholders until we have three real 8xH100 SXM runs logged.

## Why this combination

We're pulling together the pieces that have actually worked across the top submissions:

- mixed int5 MLP / int6 attention export, which buys enough artifact budget for a 10th layer
- SmearGate + BigramHash for cheap token-pair context (sketched after this list)
- orthogonal init with muP-style output projection scaling
- decoupled Muon weight decay at 0.04
- SWA during warmdown
- late QAT kicking in at 85% of wallclock (not always-on STE, which has consistently underperformed)
- sliding-window eval, stride 64, with full-tail handling
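
The sketch below gives one plausible reading of the SmearGate + BigramHash bullet: a learned sigmoid gate smears the previous token's embedding into the current one, and adjacent token-id pairs are hashed into a small embedding table. The gating form, hash function, and module layout are our assumptions rather than the contents of `train_gpt.py`; only `BIGRAM_VOCAB_SIZE=4096` and `BIGRAM_DIM=128` come from the recipe.

```python
import torch
import torch.nn as nn

class SmearGateBigram(nn.Module):
    """Hypothetical SmearGate + BigramHash block (see assumptions above)."""

    def __init__(self, dim: int = 512, bigram_vocab: int = 4096, bigram_dim: int = 128):
        super().__init__()
        self.gate = nn.Linear(dim, 1)                    # per-position smear gate
        self.bigram = nn.Embedding(bigram_vocab, bigram_dim)
        self.proj = nn.Linear(bigram_dim, dim, bias=False)

    def forward(self, x: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
        # x: (B, T, dim) token embeddings; ids: (B, T) token ids
        prev = torch.roll(x, 1, dims=1)
        prev[:, 0] = 0.0                                 # no predecessor at t=0
        x = x + torch.sigmoid(self.gate(x)) * prev       # SmearGate: smear previous token in
        prev_ids = torch.roll(ids, 1, dims=1)
        prev_ids[:, 0] = 0
        h = (prev_ids * 31 + ids) % self.bigram.num_embeddings  # cheap pair hash
        return x + self.proj(self.bigram(h))             # BigramHash context
```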

The idea is straightforward: start from the best public 10L mixed-precision stack, keep the local-context gains that SmearGate runs hold over the older int6/MLP3x cluster, and add only the late-stage export-aware training that keeps beating full-run STE in head-to-head comparisons.
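
As a rough illustration of the late-QAT piece, here is a minimal PyTorch sketch assuming symmetric per-tensor fake quantization with a straight-through estimator and a wallclock-gated switch; the bit-width routing, gating, and quantizer details in the actual `train_gpt.py` may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def fake_quant(w: torch.Tensor, bits: int) -> torch.Tensor:
    # Symmetric per-tensor fake quantization with a straight-through estimator
    # (assumption: the record script may use per-channel scales instead).
    qmax = 2 ** (bits - 1) - 1
    scale = w.detach().abs().amax().clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax) * scale
    return w + (q - w).detach()          # forward: quantized; backward: identity

class QATLinear(nn.Linear):
    """Linear whose weight is fake-quantized once late QAT switches on."""

    def __init__(self, *args, bits: int = 6, **kwargs):
        super().__init__(*args, **kwargs)
        self.bits = bits                 # e.g. 5 for MLP weights, 6 for attention
        self.qat_active = False          # flipped late in the run

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = fake_quant(self.weight, self.bits) if self.qat_active else self.weight
        return F.linear(x, w, self.bias)
```

The training loop would flip `qat_active` on every quantized layer once `elapsed >= QAT_START_FRAC * max_wallclock_seconds`, so the weights only see quantization noise for the final 15% of the run, matching the `QAT_START_FRAC=0.85` setting below.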

## Default recipe

- `NUM_LAYERS=10`
- `MODEL_DIM=512`
- `NUM_HEADS=8`
- `NUM_KV_HEADS=4`
- `MLP_MULT=3.0`
- `TRAIN_BATCH_TOKENS=786432`
- `TRAIN_SEQ_LEN=2048`
- `EVAL_SEQ_LEN=2048`
- `EVAL_STRIDE=64` (see the sliding-window eval sketch after this list)
- `BIGRAM_VOCAB_SIZE=4096`
- `BIGRAM_DIM=128`
- `MATRIX_LR=0.025`
- `SCALAR_LR=0.02`
- `TIED_EMBED_LR=0.03`
- `MUON_WEIGHT_DECAY=0.04`
- `ADAM_WEIGHT_DECAY=0.01`
- `SWA_ENABLED=1`
- `SWA_START_FRAC=0.5`
- `SWA_EVERY=50`
- `QAT_ENABLED=1`
- `QAT_START_FRAC=0.85`
- `KEEP_LAST_K_FP16=1`
- `REQUIRE_ZSTD=1`
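
For reference, the `EVAL_SEQ_LEN=2048` / `EVAL_STRIDE=64` pair implies a sliding-window evaluation like the sketch below, where each window scores only the tokens the previous window did not cover and the final window is clamped to the end of the stream. The batching and bpb rescaling of the real script are omitted; this is a sketch, not the record implementation.

```python
import torch

@torch.no_grad()
def sliding_window_loss(model, tokens: torch.Tensor, seq_len: int = 2048, stride: int = 64) -> float:
    # tokens: 1-D LongTensor holding the whole validation stream.
    nll_sum, n_scored, prev_end = 0.0, 0, 0
    for begin in range(0, len(tokens) - 1, stride):
        end = min(begin + seq_len, len(tokens))
        ids = tokens[begin:end].unsqueeze(0)             # (1, T)
        n_new = end - max(prev_end, begin + 1)           # targets not yet scored
        logp = torch.log_softmax(model(ids)[:, :-1].float(), dim=-1)
        tgt = ids[:, 1:]
        tok_nll = -logp.gather(-1, tgt.unsqueeze(-1)).squeeze(-1)
        nll_sum += tok_nll[:, -n_new:].sum().item()      # score only the new tail
        n_scored += n_new
        prev_end = end
        if end == len(tokens):                           # full-tail handling
            break
    # val_bpb would additionally rescale by 1/ln(2) and the tokens-per-byte
    # ratio of the raw text (omitted here).
    return nll_sum / n_scored                            # mean val_loss in nats
```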

## Before you run

The script expects `zstandard` so the artifact gets zstd-22 compression. Install it first:

```bash
pip install zstandard
```

For a quick smoke test, `REQUIRE_ZSTD=0` falls back to zlib, but that won't match the intended record setup.
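
A sketch of the export step this flag controls, using the real `zstandard` and `zlib` APIs; the serialization format and the flag handling around it are assumptions, not lifted from `train_gpt.py`.

```python
import io
import os
import zlib

import torch

def compress_artifact(model: torch.nn.Module) -> bytes:
    """Serialize the model and compress it toward the 16MB artifact budget."""
    buf = io.BytesIO()
    torch.save(model.state_dict(), buf)
    raw = buf.getvalue()
    try:
        import zstandard
        return zstandard.ZstdCompressor(level=22).compress(raw)  # record setup
    except ImportError:
        if os.environ.get("REQUIRE_ZSTD", "1") == "1":
            raise RuntimeError("zstd-22 required; run `pip install zstandard`")
        return zlib.compress(raw, 9)     # REQUIRE_ZSTD=0 smoke-test fallback
```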

## Reproduction

```bash
pip install zstandard

RUN_ID=10l_int5mlp_smearbigram_lateqat_seed1337 \
DATA_PATH=./data/datasets/fineweb10B_sp1024 \
TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \
VOCAB_SIZE=1024 \
SEED=1337 \
torchrun --standalone --nproc_per_node=8 \
  ./records/track_10min_16mb/2026-03-20_10L_Int5MLP_SmearBigram_LateQAT/train_gpt.py
```

Three-seed sweep:

```bash
for SEED in 1337 42 7; do
  RUN_ID=10l_int5mlp_smearbigram_lateqat_seed${SEED} \
  DATA_PATH=./data/datasets/fineweb10B_sp1024 \
  TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \
  VOCAB_SIZE=1024 \
  SEED=${SEED} \
  torchrun --standalone --nproc_per_node=8 \
    ./records/track_10min_16mb/2026-03-20_10L_Int5MLP_SmearBigram_LateQAT/train_gpt.py
done
```

## What to record after the runs

- `final_mixed_int5int6_roundtrip_exact val_loss:<...> val_bpb:<...>`
- `Serialized model mixed-int5int6+zstd: <...> bytes`
- `Total submission size mixed-int5int6: <...> bytes`
- last train step reached under the 600s cap
- eval wall time
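
All of those lines appear verbatim in the log, so a small helper can pull them out. The patterns below follow the log format shown further down; the helper itself is just a convenience we are suggesting, not part of the record tooling.

```python
import re
import sys

# Record-keeping lines to extract from a finished train.log; adjust the
# patterns if the script's output format changes.
PATTERNS = [
    r"final_mixed_int5int6_roundtrip_exact .*",
    r"Serialized model mixed-int5int6\+zstd: \d+ bytes",
    r"Total submission size mixed-int5int6: \d+ bytes",
    r"stopping_early: wallclock_cap .*",
    r"final_mixed_int5int6_roundtrip .*eval_time:\d+ms",
]

def extract(log_path: str) -> None:
    text = open(log_path).read()
    for pat in PATTERNS:
        for m in re.finditer(pat, text):
            print(m.group(0))

if __name__ == "__main__":
    extract(sys.argv[1] if len(sys.argv) > 1 else "train.log")
```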

## Files

- `train_gpt.py` is the record candidate script
- `train.log` is a placeholder; replace it with real multi-seed logs
- `submission.json` is the metadata template; fill it in after the runs
---

`submission.json`:

```json
{
"author": "Christopher Buckley",
"github_id": "chris-buckley",
"name": "10L Int5-MLP + SmearGate + BigramHash + Late QAT",
"blurb": "Mixed int5/int6 export funds a 10-layer ReLU^2 model with SmearGate, BigramHash, orthogonal init, Muon weight decay, SWA, and late mixed QAT under the 16MB cap.",
"date": "2026-03-20",
"val_loss": 1.96338430,
"val_bpb": 1.16282670,
"pre_quant_val_loss": 2.0105,
"pre_quant_val_bpb": 1.1907,
"step_stop": 4354,
"wallclock_seconds": 602,
"eval_time_seconds": 171,
"bytes_total": 15481841,
"bytes_model_mixed_int5int6": 15425120,
"bytes_code": 56721,
"notes": "Single seed (1337) on 8xH100 SXM. SWA averaged 15 checkpoints. Did not beat MLP3x (1.1598). Needs seeds 42 and 7 for statistical significance if submitting."
}
```

---

`train.log`:

```
Seed 1337 — 8xH100 SXM — 2026-03-20

val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021845
model_params:24730705
world_size:8 grad_accum_steps:1
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.03 matrix_lr:0.025 scalar_lr:0.02
train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
eval_seq_len:2048 eval_stride:64 eval_batch_seqs:32 muon_weight_decay:0.04 adam_weight_decay:0.01 qat_enabled:True qat_start_frac:0.85 keep_last_k_fp16:1
seed:1337

step:0/20000 val_loss:6.9283 val_bpb:4.1033 train_time:0ms step_avg:0.02ms
step:1/20000 train_loss:6.9295 train_time:10084ms step_avg:10084.14ms
step:2/20000 train_loss:7.9521 train_time:10159ms step_avg:5079.38ms
step:3/20000 train_loss:7.9764 train_time:10258ms step_avg:3419.17ms
step:4/20000 train_loss:7.8828 train_time:10355ms step_avg:2588.83ms
step:5/20000 train_loss:7.6879 train_time:10457ms step_avg:2091.40ms
step:6/20000 train_loss:7.4733 train_time:10556ms step_avg:1759.29ms
step:7/20000 train_loss:7.1612 train_time:10653ms step_avg:1521.89ms
step:8/20000 train_loss:6.8910 train_time:10753ms step_avg:1344.09ms
step:9/20000 train_loss:6.5349 train_time:10850ms step_avg:1205.59ms
step:10/20000 train_loss:6.2947 train_time:10949ms step_avg:1094.88ms
step:200/20000 train_loss:2.4592 train_time:30770ms step_avg:153.85ms
step:400/20000 train_loss:2.4406 train_time:54762ms step_avg:136.91ms
step:600/20000 train_loss:2.3558 train_time:76052ms step_avg:126.75ms
step:800/20000 train_loss:2.2588 train_time:101283ms step_avg:126.60ms
step:1000/20000 train_loss:2.2898 train_time:122568ms step_avg:122.57ms
step:1000/20000 val_loss:2.2448 val_bpb:1.3295 train_time:122623ms step_avg:122.62ms
step:1200/20000 train_loss:2.3674 train_time:146272ms step_avg:121.89ms
step:1400/20000 train_loss:2.1949 train_time:170151ms step_avg:121.54ms
step:1600/20000 train_loss:2.0863 train_time:191217ms step_avg:119.51ms
step:1800/20000 train_loss:2.1669 train_time:216406ms step_avg:120.23ms
step:2000/20000 train_loss:2.0792 train_time:238186ms step_avg:119.09ms
step:2000/20000 val_loss:2.1408 val_bpb:1.2679 train_time:238223ms step_avg:119.11ms
step:2200/20000 train_loss:2.1475 train_time:262401ms step_avg:119.27ms
step:2400/20000 train_loss:2.0647 train_time:283002ms step_avg:117.92ms
step:2600/20000 train_loss:2.1022 train_time:307114ms step_avg:118.12ms
step:2800/20000 train_loss:2.1452 train_time:331363ms step_avg:118.34ms
step:3000/20000 train_loss:2.1465 train_time:352108ms step_avg:117.37ms
step:3000/20000 val_loss:2.0762 val_bpb:1.2297 train_time:352148ms step_avg:117.38ms
step:3200/20000 train_loss:2.1522 train_time:375805ms step_avg:117.44ms
step:3400/20000 train_loss:1.9978 train_time:396620ms step_avg:116.65ms
step:3600/20000 train_loss:2.0664 train_time:421123ms step_avg:116.98ms
swa:start step:3650
step:3800/20000 train_loss:2.0404 train_time:442619ms step_avg:116.48ms
step:4000/20000 train_loss:1.9381 train_time:467484ms step_avg:116.87ms
step:4000/20000 val_loss:2.0281 val_bpb:1.2012 train_time:467567ms step_avg:116.89ms
step:4200/20000 train_loss:2.1112 train_time:493201ms step_avg:117.43ms
step:4354/20000 val_loss:2.0105 val_bpb:1.1907 train_time:601991ms step_avg:138.26ms
stopping_early: wallclock_cap train_time:601991ms step:4354/20000
peak memory allocated: 18864 MiB reserved: 20234 MiB
swa:applying averaged 15 checkpoints
Serialized model: 96864555 bytes
Code size: 56721 bytes
Total submission size: 96921276 bytes
Serialized model mixed-int5int6+zstd: 15425120 bytes
Total submission size mixed-int5int6: 15481841 bytes
final_eval_mode:sliding_window seq_len:2048 stride:64 batch_seqs:32
final_mixed_int5int6_roundtrip val_loss:1.9634 val_bpb:1.1628 eval_time:170665ms
final_mixed_int5int6_roundtrip_exact val_loss:1.96338430 val_bpb:1.16282670
```