Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# 11L XSA4 + EMA + Int6 MLP3x + Full-Model SGD TTT

**Mean val_bpb: 1.1442** (2 seeds on 8xH100 SXM), best: 1.1419 (seed 1337)

This is the highest-EV convergence branch. It keeps the strongest public training stack intact:
- 11 layers, 512 dim, 8 heads / 4 KV heads
- MLP 3x
- SmearGate + BigramHash(2048)
- OrthoInit + muP-style output scaling
- Muon/AdamW with WD=0.04
- int6 mixed quantization + zstd-22
- XSA on the last 4 layers
- EMA instead of SWA

Then it adds a single orthogonal eval-time change:
- full-model SGD TTT on the dequantized checkpoint
- 3 epochs
- lr=0.002
- momentum=0.9
- freeze the first 2 blocks

## Diverge, then converge

I considered four branches:
1. Keep the current best public training stack and add only TTT.
2. Keep the current best public stack and also retune batch size / RoPE / other knobs.
3. Go back to a 10L int5 / late-QAT branch.
4. Keep exploring byte-aware auxiliaries and curricula.

I converge on branch 1.

Reason:
- The 11L EMA + XSA stack is already the strongest public training-time base.
- Full-model SGD TTT is the strongest proven orthogonal eval-time add-on.
- The “stack many more things at once” branch already looks weaker and harder to reason about.
- The byte-aware branch already underperformed badly in practice.

## Why this differs from prior winners

Compared with the main winning branches:
- vs PR #198: this keeps the 11L SmearGate/Bigram/WD/int6 recipe but upgrades SWA to EMA and adds XSA.
- vs PR #287: this keeps the exact best public training stack and adds full-model SGD TTT.
- vs PR #254: this uses a much stronger base model before TTT.
- vs PR #290: this avoids batch-size and RoPE retunes and makes only one new move on top of the best base.
- vs my earlier byte-aligned run: this does not spend training budget on auxiliary objectives that failed to pay back.

## Practical notes

This script includes a FlashAttention-3 import fallback:
- if `flash_attn_interface` is available, it uses FA3
- otherwise it falls back to PyTorch SDPA and still runs

That makes it safer on the current RunPod template.

The artifact target still depends on `zstandard` being available. If the script falls back to zlib, quality may stay fine but the compressed artifact may no longer be competitive.

## Default config

The script defaults already encode the intended record-track settings:
- `NUM_LAYERS=11`
- `TRAIN_BATCH_TOKENS=786432`
- `TRAIN_SEQ_LEN=2048`
- `EVAL_STRIDE=64`
- `BIGRAM_VOCAB_SIZE=2048`
- `XSA_LAST_N=4`
- `EMA_ENABLED=1`
- `SWA_ENABLED=0`
- `MUON_WD=0.04`
- `ADAM_WD=0.04`
- `MATRIX_LR=0.025`
- `SCALAR_LR=0.025`
- `TIED_EMBED_LR=0.035`
- `TTT_ENABLED=1`
- `TTT_LR=0.002`
- `TTT_EPOCHS=3`
- `TTT_FREEZE_BLOCKS=2`

## Run

From repo root on an 8xH100 pod:

```bash
torchrun --standalone --nproc_per_node=8 records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_TTT_Int6_MLP3x_WD04/train_gpt.py
```

I would run three seeds first:

```bash
for SEED in 1337 42 2025; do
RUN_ID=xsa4_ema_ttt_$SEED \
SEED=$SEED \
torchrun --standalone --nproc_per_node=8 \
records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_TTT_Int6_MLP3x_WD04/train_gpt.py
done
```

## Results

| Seed | Pre-quant val_bpb | Post-int6 val_bpb | Post-TTT val_bpb | Steps | ms/step | Artifact bytes |
|------|-------------------|-------------------|-------------------|-------|---------|----------------|
| 1337 | 1.1581 | 1.1655 | **1.1419** | 5,344 | 109.2 | 15,578,775 |
| 1338 | 1.1616 | 1.1701 | **1.1464** | 4,559 | 131.6 | 15,661,221 |
| **Mean** | | | **1.1442** | | | |

Hardware: 8xH100 SXM (community cloud). SDPA fallback (no FA3).
Seed 1337: ~109ms/step. Seed 1338: ~132ms/step (different node).
TTT: 3 epochs SGD. Eval: stride-64 sliding window.
All artifacts under 16MB (zstd-22 compression).

## vs. Prior SOTA

| Run | val_bpb |
|-----|---------|
| Compression-Funded MLP3x (best seed) | 1.1598 |
| Compression-Funded MLP3x (mean) | 1.1647 |
| **This run (best seed 1337)** | **1.1419** |
| **This run (2-seed mean)** | **1.1442** |
| Delta (best vs best) | **-0.0179** |
| Delta (mean vs mean) | **-0.0205** |

## Compatibility fixes applied

- SDPA GQA fallback: manual KV head repeat for PyTorch <2.5 (no `enable_gqa`)
- RoPE cache clear before TTT: prevents "inference tensors cannot be saved for backward" error
- Requires `zstandard` pip package for zstd-22 compression (falls back to zlib otherwise, overshoots 16MB)
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"author": "Christopher Buckley",
"github_id": "chris-buckley",
"name": "11L XSA4 + EMA + Int6 MLP3x + Full-Model SGD TTT",
"blurb": "Best public training stack (11L, MLP3x, SmearGate, BigramHash, XSA4, EMA, int6+zstd-22) plus full-model SGD TTT at eval time. SDPA fallback for PyTorch <2.5 compatibility.",
"date": "2026-03-21",
"val_loss": 1.92805096,
"val_bpb": 1.14190341,
"pre_quant_val_loss": 1.9553,
"pre_quant_val_bpb": 1.1581,
"post_int6_val_loss": 1.95003356,
"post_int6_val_bpb": 1.16546619,
"step_stop": 5344,
"wallclock_seconds": 600.084,
"eval_time_seconds": 116.978,
"ttt_time_seconds": 50.4,
"bytes_total": 15578775,
"bytes_code": 70565,
"notes": "2 seeds on 8xH100 SXM (community cloud). SDPA fallback (no FA3). Seed 1337: val_bpb=1.1419, 5344 steps, ~109ms/step. Seed 1338: val_bpb=1.1464, 4559 steps, ~132ms/step. Mean val_bpb=1.1442. TTT: 3 epochs SGD, lr=0.002, momentum=0.9, freeze first 2 blocks."
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
=== Seed 1337 (8xH100 SXM, community cloud, SDPA fallback) ===
logs/2026-03-21_11L_XSA4_EMA_TTT_Int6_MLP3x_WD04_seed1337.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:26829913
mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
world_size:8 grad_accum_steps:1
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_backend:sdpa
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
ttt_enabled:True ttt_lr:0.002 ttt_epochs:3 ttt_momentum:0.9 ttt_batch_seqs:32 ttt_freeze_blocks:2
train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
step:0/9000 val_loss:6.9294 val_bpb:4.1040 train_time:0ms step_avg:0.01ms
step:1/9000 train_loss:6.9326 train_time:152ms step_avg:152.45ms
step:2/9000 train_loss:8.2239 train_time:248ms step_avg:124.21ms
step:3/9000 train_loss:7.5772 train_time:362ms step_avg:120.81ms
step:4/9000 train_loss:8.3141 train_time:477ms step_avg:119.14ms
step:5/9000 train_loss:8.5845 train_time:591ms step_avg:118.11ms
step:6/9000 train_loss:8.3501 train_time:704ms step_avg:117.38ms
step:7/9000 train_loss:7.7042 train_time:818ms step_avg:116.87ms
step:8/9000 train_loss:7.0360 train_time:931ms step_avg:116.40ms
step:9/9000 train_loss:6.5336 train_time:1048ms step_avg:116.44ms
step:10/9000 train_loss:6.1342 train_time:1163ms step_avg:116.31ms
step:200/9000 train_loss:2.4273 train_time:21860ms step_avg:109.30ms
step:400/9000 train_loss:2.4402 train_time:43770ms step_avg:109.42ms
step:600/9000 train_loss:2.3470 train_time:65595ms step_avg:109.32ms
step:800/9000 train_loss:2.2475 train_time:87504ms step_avg:109.38ms
step:1000/9000 train_loss:2.2855 train_time:109323ms step_avg:109.32ms
step:1200/9000 train_loss:2.3620 train_time:131231ms step_avg:109.36ms
step:1400/9000 train_loss:2.1963 train_time:153127ms step_avg:109.38ms
step:1600/9000 train_loss:2.0834 train_time:174901ms step_avg:109.31ms
step:1800/9000 train_loss:2.1727 train_time:196746ms step_avg:109.30ms
step:2000/9000 train_loss:2.0694 train_time:218514ms step_avg:109.26ms
step:2000/9000 val_loss:2.1388 val_bpb:1.2667 train_time:218546ms step_avg:109.27ms
step:2200/9000 train_loss:2.1461 train_time:240378ms step_avg:109.26ms
step:2400/9000 train_loss:2.0735 train_time:262157ms step_avg:109.23ms
step:2600/9000 train_loss:2.1135 train_time:284009ms step_avg:109.23ms
step:2800/9000 train_loss:2.1537 train_time:305885ms step_avg:109.24ms
step:3000/9000 train_loss:2.1527 train_time:327676ms step_avg:109.23ms
step:3200/9000 train_loss:2.1613 train_time:349543ms step_avg:109.23ms
step:4000/9000 val_loss:2.0368 val_bpb:1.2063 train_time:436870ms step_avg:109.22ms
step:5344/9000 val_loss:1.9553 val_bpb:1.1581 train_time:600084ms step_avg:112.29ms
stopping_early: wallclock_cap train_time:600084ms step:5344/9000
ema:applying EMA weights
Serialized model: 105783807 bytes
Code size: 70565 bytes
Total submission size int6+zstd: 15578775 bytes
final_int6_roundtrip val_loss:1.9678 val_bpb:1.1655 eval_time:56692ms
final_int6_roundtrip_exact val_loss:1.96784084 val_bpb:1.16546619
ttt_epoch:1/3 loss:1.9677 time:27.2s
ttt_epoch:2/3 loss:1.9495 time:33.7s
ttt_epoch:3/3 loss:1.9490 time:50.4s
ttt:done elapsed:50.4s
ttt_elapsed:50434ms
final_int6_ttt_sliding_window val_loss:1.9281 val_bpb:1.1419 stride:64 eval_time:116978ms
final_int6_ttt_sliding_window_exact val_loss:1.92805096 val_bpb:1.14190341

=== Seed 1338 (8xH100 SXM, community cloud, SDPA fallback) ===
logs/2026-03-21_11L_XSA4_EMA_TTT_Int6_MLP3x_WD04_seed1338.txt
seed:1338
step:0/9000 val_loss:6.9282 val_bpb:4.1033 train_time:0ms step_avg:0.01ms
step:200/9000 train_loss:2.4415 train_time:24610ms step_avg:123.05ms
step:400/9000 train_loss:2.4334 train_time:52458ms step_avg:131.14ms
step:600/9000 train_loss:2.3407 train_time:77142ms step_avg:128.57ms
step:800/9000 train_loss:2.2453 train_time:104371ms step_avg:130.46ms
step:1000/9000 train_loss:2.2771 train_time:128790ms step_avg:128.79ms
step:1200/9000 train_loss:2.3645 train_time:157545ms step_avg:131.29ms
step:1400/9000 train_loss:2.1963 train_time:183439ms step_avg:131.03ms
step:1600/9000 train_loss:2.0834 train_time:210048ms step_avg:131.28ms
step:1800/9000 train_loss:2.1727 train_time:236269ms step_avg:131.26ms
step:2000/9000 val_loss:2.1234 val_bpb:1.2576 train_time:261681ms step_avg:130.84ms
step:2600/9000 train_loss:2.0817 train_time:342242ms step_avg:131.63ms
step:2800/9000 train_loss:2.1241 train_time:370176ms step_avg:132.21ms
step:3000/9000 train_loss:2.1261 train_time:394507ms step_avg:131.50ms
step:3200/9000 train_loss:2.1287 train_time:421393ms step_avg:131.69ms
step:3400/9000 train_loss:1.9700 train_time:446592ms step_avg:131.35ms
step:3600/9000 train_loss:2.0328 train_time:474417ms step_avg:131.78ms
step:3800/9000 train_loss:2.0075 train_time:499179ms step_avg:131.36ms
step:4000/9000 val_loss:1.9925 val_bpb:1.1801 train_time:526776ms step_avg:131.69ms
step:4200/9000 train_loss:2.0724 train_time:554805ms step_avg:132.10ms
step:4559/9000 val_loss:1.9613 val_bpb:1.1616 train_time:600030ms step_avg:131.61ms
stopping_early: wallclock_cap train_time:600030ms step:4559/9000
Total submission size int6+zstd: 15661221 bytes
final_int6_roundtrip_exact val_loss:1.97563807 val_bpb:1.17008415
final_int6_ttt_sliding_window val_loss:1.9356 val_bpb:1.1464 stride:64 eval_time:106682ms
final_int6_ttt_sliding_window_exact val_loss:1.93563545 val_bpb:1.14639537
Loading