# Late STE QAT + Int6 MLP3x + SmearGate + BigramHash + OrthoInit + Overtone + SWA + SGD TTT

## Score

**Measured (seed=1337, single run):** `val_bpb = 1.16292025` · `val_loss = 1.96353693` (after int6+zstd roundtrip + sliding-window eval; see `train.log`).

Trained on **8×H100 SXM** with a **600s** wallclock cap (`step=5464`). **Total submission size `15,948,643` bytes** (~15.95 MB decimal), **below** the 16,000,000-byte limit — **int6 + zstd-22** artifact plus UTF-8 `train_gpt.py` bytes (`64,426`).

> *Note:* The template-style multi-seed table below is **not** part of this folder’s logs; only **seed 1337** is recorded here. Re-run with other `SEED` values if you want a proper mean/std.

## Approach

This run stacks several techniques on a **9-layer, 512-dim** GPT-style model, adding **late STE QAT**, **Overtone-style init**, and optional **full-model SGD TTT** (the script defaults to SGD TTT on, LoRA TTT off).

### 1. Per-row int6 quantization + zstd-22

MLP and attention weight matrices are quantized to int6 (signed range `[-32, 31]`) with **per-row scaling**. Tied embeddings stay in a higher-precision path where it matters; the implementation follows the repo’s mixed-quant rules. After `torch.save` of the quantized payload, the blob is compressed with **zstd level 22** (`zstandard`), which is typically a few percent smaller than **zlib-9** on the same bytes; here that margin is enough to land **under** the decimal 16MB cap where zlib did not.
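
A minimal sketch of the per-row int6 round-trip plus zstd-22 serialization, assuming a plain dict payload of `(int values, per-row scales)` pairs; the function names are illustrative, not the script's actual API:

```python
import io
import torch
import zstandard as zstd

def quantize_rows_int6(w: torch.Tensor):
    # Symmetric per-row quantization onto the signed int6 grid [-32, 31].
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 31.0
    q = torch.clamp(torch.round(w / scale), -32, 31).to(torch.int8)  # int6 values in int8 storage
    return q, scale

def dequantize_rows_int6(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.float() * scale

def save_int6_zstd(payload: dict, path: str, level: int = 22) -> int:
    # torch.save the quantized payload, then compress the blob with zstd level 22.
    buf = io.BytesIO()
    torch.save(payload, buf)
    blob = zstd.ZstdCompressor(level=level).compress(buf.getvalue())
    with open(path, "wb") as f:
        f.write(blob)
    return len(blob)  # compressed artifact size in bytes

def load_int6_zstd(path: str) -> dict:
    with open(path, "rb") as f:
        blob = zstd.ZstdDecompressor().decompress(f.read())
    return torch.load(io.BytesIO(blob))
```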

### 2. 3× MLP expansion

Hidden FFN width of **1536** (3× the model dim) instead of the more common 2× (**1024**); the extra parameters are paid for in the byte budget by int6 quantization and strong compression.
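
A sketch of the widened FFN, assuming a standard two-matrix GPT-style MLP (the activation choice here is illustrative):

```python
import torch.nn as nn

class MLP(nn.Module):
    # FFN with hidden width = 3 * model_dim (1536 at dim=512) instead of the usual 2x.
    def __init__(self, dim: int = 512, mult: float = 3.0):
        super().__init__()
        hidden = int(dim * mult)                     # 1536
        self.fc_in = nn.Linear(dim, hidden, bias=False)
        self.fc_out = nn.Linear(hidden, dim, bias=False)
        self.act = nn.GELU()

    def forward(self, x):
        return self.fc_out(self.act(self.fc_in(x)))
```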

### 3. SmearGate

A learned gate blends each token’s embedding with the **previous** token’s embedding, giving a cheap bigram-like signal at the embedding layer (on the order of **~512** extra parameters in the usual setup).
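
A minimal sketch of the smear-gate idea, assuming one learned gate value per embedding channel (~512 parameters at dim 512); the exact parameterization in `train_gpt.py` may differ:

```python
import torch
import torch.nn as nn

class SmearGate(nn.Module):
    # Blends each token embedding with the previous token's embedding through a
    # learned per-channel gate: x_t <- (1 - g) * x_t + g * x_{t-1}.
    def __init__(self, dim: int = 512):
        super().__init__()
        self.gate = nn.Parameter(torch.zeros(dim))   # ~512 extra parameters

    def forward(self, x: torch.Tensor) -> torch.Tensor:   # x: (batch, seq, dim)
        prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        g = torch.sigmoid(self.gate)
        return (1 - g) * x + g * prev
```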

### 4. BigramHash embedding

A **4096**-bucket table (e.g. dim **128**, projected to model width) keyed by adjacent token pairs via a small hash of `(prev, curr)`. Adds on the order of **~0.5M** parameters and complements SmearGate with an **additive** bigram path.
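
A sketch of the hashed bigram path, assuming a 4096-bucket table of dim 128 projected up to the model width (4096×128 + 128×512 ≈ 0.59M parameters with these shapes); the hash constants below are illustrative:

```python
import torch
import torch.nn as nn

class BigramHashEmbedding(nn.Module):
    # Additive bigram signal: hash the (prev, curr) token pair into a small bucket
    # table, then project the bucket embedding up to the model width.
    def __init__(self, n_buckets: int = 4096, bigram_dim: int = 128, model_dim: int = 512):
        super().__init__()
        self.n_buckets = n_buckets
        self.table = nn.Embedding(n_buckets, bigram_dim)          # 4096 x 128
        self.proj = nn.Linear(bigram_dim, model_dim, bias=False)  # 128 x 512

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:      # tokens: (batch, seq) int64
        prev = torch.cat([torch.zeros_like(tokens[:, :1]), tokens[:, :-1]], dim=1)
        # Simple multiplicative hash of (prev, curr); the constants are arbitrary primes.
        h = (prev * 1000003 + tokens * 10007) % self.n_buckets
        return self.proj(self.table(h))               # added to the token embeddings
```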

### 5. Orthogonal init (+ muP-style scaling)

Large matrices are initialized orthogonally where applicable; readouts are scaled with depth-aware factors consistent with muP-style training in this codebase.
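
A sketch of such an init pass, assuming orthogonal init on 2-D weights and a depth-aware 1/√(2·n_layers) down-scaling on residual-branch output projections; the module names checked here are illustrative, not the script's:

```python
import math
import torch.nn as nn

def init_weights(model: nn.Module, num_layers: int = 9):
    # Orthogonal init for 2-D weight matrices; depth-aware down-scaling on residual
    # output projections, in the spirit of muP-style parameterizations.
    for name, p in model.named_parameters():
        if p.dim() == 2:
            nn.init.orthogonal_(p)
            if name.endswith(("attn_out.weight", "fc_out.weight")):  # illustrative names
                p.data.mul_(1.0 / math.sqrt(2 * num_layers))
```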

### 6. Muon + AdamW, weight decay

**Muon** on matrix blocks with tuned **weight decay** and momentum schedule; scalar/embedding groups use **AdamW** with their own WD. This run uses **`muon_weight_decay=0.038`**, **`matrix_lr=0.025`** (see env overrides in `train_gpt.py`).
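
A sketch of the optimizer split and the momentum schedule from the hyperparameter table below, assuming matrices are collected for Muon by shape and name; the Muon optimizer itself is defined in `train_gpt.py` and is not reproduced here:

```python
import torch

def build_param_groups(model: torch.nn.Module):
    # Matrix blocks go to Muon (matrix_lr=0.025, weight_decay=0.038, momentum warmup below);
    # scalars and the tied embedding go to AdamW with their own LRs and weight decay.
    matrices = [p for n, p in model.named_parameters() if p.ndim == 2 and "embed" not in n]
    embeds   = [p for n, p in model.named_parameters() if p.ndim == 2 and "embed" in n]
    scalars  = [p for n, p in model.named_parameters() if p.ndim != 2]
    adamw = torch.optim.AdamW(
        [{"params": scalars, "lr": 0.02},    # scalar_lr
         {"params": embeds,  "lr": 0.03}],   # tied_embed_lr
        weight_decay=0.01)
    return matrices, adamw   # `matrices` is handed to the Muon optimizer in train_gpt.py

def muon_momentum(step: int, warmup_steps: int = 1500,
                  start: float = 0.92, end: float = 0.99) -> float:
    # Linear momentum warmup from 0.92 to 0.99 over the first 1500 steps.
    return start + (end - start) * min(step / warmup_steps, 1.0)
```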

### 7. Stochastic weight averaging (SWA)

SWA accumulates an average of the weights over the tail of training (from **`swa_start_frac=0.5`** onward), updating every **`swa_every`** steps (default **`200`** in this script). The logged run averaged **5** checkpoints before quantization.
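
A sketch of the checkpoint averaging, assuming a simple running mean of float weights taken at each SWA update (the class name and interface are illustrative):

```python
import torch

class SWAAverager:
    # Running mean of the float weights, updated every `swa_every` steps once training
    # passes `swa_start_frac`; the logged run averaged 5 snapshots.
    def __init__(self):
        self.avg, self.count = None, 0

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        snap = {k: v.detach().float().clone()
                for k, v in model.state_dict().items() if v.is_floating_point()}
        if self.avg is None:
            self.avg = snap
        else:
            for k in self.avg:
                self.avg[k] += (snap[k] - self.avg[k]) / (self.count + 1)
        self.count += 1

    @torch.no_grad()
    def copy_to(self, model: torch.nn.Module):
        # strict=False: non-float buffers (if any) are left untouched.
        model.load_state_dict(self.avg, strict=False)
```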

### 8. Late STE QAT (last ~15% of wallclock)

**Fake-quant (STE)** for int6 is only enabled after **`qat_start_frac≈0.85`** of the wallclock budget, with **`qat_lr_factor=0.5`** on the affected optimizer groups when QAT turns on — so Muon is not fighting quant noise for the whole run.
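
A sketch of the straight-through fake-quant applied once the wallclock fraction crosses `qat_start_frac`; the forward grid matches the per-row int6 export above, and the loop snippet is illustrative:

```python
import torch

def fake_quant_int6_ste(w: torch.Tensor) -> torch.Tensor:
    # Forward pass sees the same per-row int6 grid the export uses; backward is the
    # identity (straight-through), so gradients flow as if no quantization happened.
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 31.0
    q = torch.clamp(torch.round(w / scale), -32, 31) * scale
    return w + (q - w).detach()

# Illustrative use inside the step loop:
#   if not qat_active and elapsed_s / 600.0 >= 0.85:   # qat_start_frac of the wallclock cap
#       qat_active = True
#       for group in all_param_groups:
#           group["lr"] *= 0.5                          # qat_lr_factor applied once at activation
```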

### 9. Full-model SGD test-time training (optional)

A short **SGD** pass on the validation stream (**not** LoRA) adapts all weights, including the gates and bigram paths that LoRA often misses. Controlled by **`SGD_TTT_ENABLED`** / **`TTT_LORA_ENABLED`**.
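
A sketch of the full-model test-time-training pass, assuming plain SGD with momentum over batches drawn from the evaluation stream before final scoring; loop structure and names are illustrative:

```python
import torch
import torch.nn.functional as F

def sgd_ttt(model: torch.nn.Module, val_batches, lr: float = 3e-4, momentum: float = 0.95):
    # One short pass over the validation stream with plain SGD on *all* parameters
    # (not LoRA adapters), so gates and bigram tables are adapted too.
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    model.train()
    for inputs, targets in val_batches:           # (batch, seq) token id tensors
        logits = model(inputs)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
    model.eval()
```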

## Main Hyperparameters

| Parameter | Value (this script / logged run) |
|-----------|----------------------------------|
| num_layers | 9 |
| model_dim | 512 |
| mlp_mult | 3.0 (hidden=1536) |
| train_seq_len | 2048 |
| train_batch_tokens | 786,432 |
| warmdown_iters | 3000 |
| matrix_lr | 0.025 |
| scalar_lr | 0.02 |
| tied_embed_lr | 0.03 |
| muon_momentum | 0.99 (warmup from 0.92 over 1500 steps) |
| muon_weight_decay | 0.038 |
| weight_decay (AdamW scalars) | 0.01 |
| grad_clip_norm | 0.3 |
| eval_stride | 64 |
| swa_every | 200 |
| swa_start_frac | 0.5 |
| qat_start_frac | 0.85 |
| qat_lr_factor | 0.5 |
| bigram hash buckets | 4096 |
| bigram dim | 128 |
| compressor | **zstd (level 22)** |
| SGD TTT | LR `3e-4`, momentum `0.95` (when enabled) |

## Key metrics (this snapshot)

| Item | Value |
|------|--------|
| **val_bpb** | **1.16292025** (`final_int8_zstd_roundtrip_exact`) |
| **val_loss** | **1.96353693** |
| Wallclock cap | 600s |
| Steps completed | 5464 |
| Model params (logged) | ~22.37M |
| **bytes_total** | **15,948,643** (under 16MB cap) |
| **bytes_code** | **64,426** |
| int6+zstd blob (logged) | 15,884,217 bytes |

## Reproducibility

**Logged run** (seed **1337**):

| Seed | val_loss | val_bpb |
|------|----------|---------|
| 1337 | 1.96353693 | 1.16292025 |

For multiple seeds, re-launch with e.g. `SEED=42`, `SEED=7`, etc. Byte totals and BPB can shift slightly across machines due to GPU non-determinism.

## Evaluation pipeline (order)

1. Train until the 600s cap (late QAT only in the tail).
2. Apply SWA checkpoint average.
3. Quantize to int6 + **zstd-22** → `final_model.int8.ptz`.
4. Decompress, dequantize, **sliding-window eval** (`eval_stride=64`; see the sketch after this list).
5. If enabled: **SGD TTT**, then final metrics.
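
A sketch of the sliding-window evaluation in step 4, assuming a `seq_len` context advanced by `stride=64` tokens with only the last `stride` positions of each window scored; converting bits/token to `val_bpb` additionally requires the tokenizer's bytes-per-token ratio:

```python
import math
import torch
import torch.nn.functional as F

@torch.no_grad()
def sliding_window_eval(model, tokens: torch.Tensor, seq_len: int = 2048, stride: int = 64):
    # Score tokens with up to `seq_len` of left context, advancing the window by
    # `stride`; only the last `stride` positions of each window contribute.
    total_nll, count = 0.0, 0
    for start in range(0, tokens.numel() - seq_len - 1, stride):
        window = tokens[start:start + seq_len + 1].long()
        logits = model(window[:-1].unsqueeze(0))               # (1, seq_len, vocab)
        nll = F.cross_entropy(logits[0, -stride:], window[-stride:], reduction="sum")
        total_nll += nll.item()
        count += stride
    val_loss = total_nll / count                               # mean NLL in nats/token
    bits_per_token = val_loss / math.log(2)
    return val_loss, bits_per_token                            # divide by bytes/token for val_bpb
```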

## How to reproduce

Install **zstandard** and cache FineWeb (`sp1024`) from the repo root; set **`HF_TOKEN`** if downloads require it.

```bash
pip install zstandard
export HF_TOKEN="your_token" # if needed
python3 data/cached_challenge_fineweb.py --variant sp1024
```

```bash
cd /path/to/parameter-golf

RUN_ID=late_qat_sgd_ttt_zstd \
SEED=1337 \
DATA_PATH=./data/datasets/fineweb10B_sp1024/ \
TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \
VOCAB_SIZE=1024 \
EVAL_STRIDE=64 \
SGD_TTT_ENABLED=1 \
TTT_LORA_ENABLED=0 \
torchrun --standalone --nproc_per_node=8 \
old/20/03/26-zstandard/train_gpt.py
```

## Files in this folder

| File | Purpose |
|------|---------|
| `train_gpt.py` | Training script; writes the int6+zstd artifact |
| `train.log` | Log for the run above |
| `submission.json` | Summary JSON for the challenge |
{
"track": "10min_16mb",
"date": "2026-03-20",
"name": "Late STE QAT + Int6 MLP3x + SmearGate + BigramHash + OrthoInit + Overtone + SWA + SGD TTT",
"author": "David Puertolas Merenciano",
"github_id": "davidpuertolas",
"blurb": "Late STE QAT (last 15%, per #76) avoids Muon momentum corruption while closing quant gap. Full-model SGD TTT (per #152) replaces LoRA TTT which hurts with SmearGate (#178). WD=0.038 + LR=0.025 from best validated submissions (#179, #194). Artifact: int6+zstd-22, under 16MB cap.",
"val_loss": 1.96353693,
"val_bpb": 1.16292025,
"bytes_total": 15948643,
"bytes_code": 64426
}
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:22368841
world_size:8 grad_accum_steps:1
ste_qat:late_activation at 85% wallclock, lr_factor=0.5
muon_weight_decay:0.038 matrix_lr:0.025
train_batch_tokens:786432 train_seq_len:2048 iterations:20000 max_wallclock:600s
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9311 val_bpb:4.1050 train_time:0ms step_avg:0.02ms
step:1/20000 train_loss:6.9332 train_time:137ms step_avg:137.14ms
step:2/20000 train_loss:8.4116 train_time:197ms step_avg:98.25ms
step:3/20000 train_loss:7.4632 train_time:290ms step_avg:96.55ms
step:4/20000 train_loss:7.3430 train_time:383ms step_avg:95.72ms
step:5/20000 train_loss:7.5556 train_time:476ms step_avg:95.25ms
step:6/20000 train_loss:7.6005 train_time:569ms step_avg:94.91ms
step:7/20000 train_loss:7.3236 train_time:664ms step_avg:94.79ms
step:8/20000 train_loss:6.9358 train_time:756ms step_avg:94.55ms
step:9/20000 train_loss:6.5327 train_time:849ms step_avg:94.37ms
step:10/20000 train_loss:6.2320 train_time:942ms step_avg:94.24ms
step:100/20000 train_loss:3.2125 train_time:8846ms step_avg:88.46ms
step:200/20000 train_loss:2.3915 train_time:19276ms step_avg:96.38ms
step:300/20000 train_loss:2.5587 train_time:29727ms step_avg:99.09ms
step:400/20000 train_loss:2.4284 train_time:39989ms step_avg:99.97ms
step:500/20000 train_loss:2.4169 train_time:48791ms step_avg:97.58ms
step:500/20000 val_loss:2.3788 val_bpb:1.4089 train_time:48833ms step_avg:97.67ms
step:600/20000 train_loss:2.3541 train_time:59248ms step_avg:98.75ms
step:700/20000 train_loss:2.3678 train_time:69725ms step_avg:99.61ms
step:800/20000 train_loss:2.2615 train_time:79977ms step_avg:99.97ms
step:900/20000 train_loss:2.1514 train_time:90430ms step_avg:100.48ms
step:1000/20000 train_loss:2.2985 train_time:99231ms step_avg:99.23ms
step:1000/20000 val_loss:2.2502 val_bpb:1.3327 train_time:99274ms step_avg:99.27ms
step:1100/20000 train_loss:2.3433 train_time:109642ms step_avg:99.67ms
step:1200/20000 train_loss:2.3797 train_time:120040ms step_avg:100.03ms
step:1300/20000 train_loss:2.1229 train_time:130463ms step_avg:100.36ms
step:1400/20000 train_loss:2.2070 train_time:140816ms step_avg:100.58ms
step:1500/20000 train_loss:2.2455 train_time:149603ms step_avg:99.74ms
step:1500/20000 val_loss:2.2079 val_bpb:1.3076 train_time:149645ms step_avg:99.76ms
step:1600/20000 train_loss:2.1007 train_time:160074ms step_avg:100.05ms
step:1700/20000 train_loss:2.1664 train_time:170545ms step_avg:100.32ms
step:1800/20000 train_loss:2.1823 train_time:181020ms step_avg:100.57ms
step:1900/20000 train_loss:2.1516 train_time:189803ms step_avg:99.90ms
step:2000/20000 train_loss:2.0915 train_time:200201ms step_avg:100.10ms
step:2000/20000 val_loss:2.1553 val_bpb:1.2765 train_time:200243ms step_avg:100.12ms
step:2100/20000 train_loss:2.0728 train_time:210725ms step_avg:100.35ms
step:2200/20000 train_loss:2.1654 train_time:221232ms step_avg:100.56ms
step:2300/20000 train_loss:2.1299 train_time:231688ms step_avg:100.73ms
step:2400/20000 train_loss:2.0915 train_time:240492ms step_avg:100.20ms
step:2500/20000 train_loss:2.1899 train_time:250967ms step_avg:100.39ms
step:2500/20000 val_loss:2.1292 val_bpb:1.2610 train_time:251009ms step_avg:100.40ms
step:2600/20000 train_loss:2.1329 train_time:261376ms step_avg:100.53ms
step:2700/20000 train_loss:2.1255 train_time:271801ms step_avg:100.67ms
step:2800/20000 train_loss:2.1792 train_time:282179ms step_avg:100.78ms
step:2900/20000 train_loss:2.0496 train_time:290997ms step_avg:100.34ms
step:3000/20000 train_loss:2.1832 train_time:301357ms step_avg:100.45ms
step:3000/20000 val_loss:2.1154 val_bpb:1.2529 train_time:301399ms step_avg:100.47ms
step:3100/20000 train_loss:2.0577 train_time:311765ms step_avg:100.57ms
step:3200/20000 train_loss:2.1901 train_time:322125ms step_avg:100.66ms
step:3300/20000 train_loss:2.0864 train_time:330911ms step_avg:100.28ms
step:3400/20000 train_loss:2.0324 train_time:341219ms step_avg:100.36ms
step:3500/20000 train_loss:2.1902 train_time:351632ms step_avg:100.47ms
step:3500/20000 val_loss:2.0927 val_bpb:1.2394 train_time:351674ms step_avg:100.48ms
step:3600/20000 train_loss:2.1049 train_time:362032ms step_avg:100.56ms
step:3700/20000 train_loss:2.1037 train_time:372242ms step_avg:100.61ms
step:3800/20000 train_loss:2.0835 train_time:381033ms step_avg:100.27ms
step:3900/20000 train_loss:2.0819 train_time:391431ms step_avg:100.37ms
step:4000/20000 train_loss:1.9795 train_time:401794ms step_avg:100.45ms
step:4000/20000 val_loss:2.0721 val_bpb:1.2272 train_time:401836ms step_avg:100.46ms
step:4100/20000 train_loss:2.0220 train_time:412181ms step_avg:100.53ms
step:4200/20000 train_loss:2.1582 train_time:422467ms step_avg:100.59ms
step:4300/20000 train_loss:2.0639 train_time:431253ms step_avg:100.29ms
step:4400/20000 train_loss:2.0361 train_time:441677ms step_avg:100.38ms
step:4500/20000 train_loss:2.1269 train_time:451993ms step_avg:100.44ms
step:4500/20000 val_loss:2.0494 val_bpb:1.2138 train_time:452034ms step_avg:100.45ms
swa:start step:4600
step:4600/20000 train_loss:1.8460 train_time:462381ms step_avg:100.52ms
step:4700/20000 train_loss:2.2388 train_time:471273ms step_avg:100.27ms
step:4800/20000 train_loss:2.4293 train_time:481774ms step_avg:100.37ms
step:4900/20000 train_loss:2.0529 train_time:492243ms step_avg:100.46ms
step:5000/20000 train_loss:2.1032 train_time:502591ms step_avg:100.52ms
step:5000/20000 val_loss:2.0252 val_bpb:1.1994 train_time:502682ms step_avg:100.54ms
qat:activated step:5084 elapsed:536517ms (89.4% wallclock)
step:5100/20000 train_loss:2.1292 train_time:564634ms step_avg:110.71ms
step:5200/20000 train_loss:2.0344 train_time:573415ms step_avg:110.27ms
step:5300/20000 train_loss:2.0009 train_time:583883ms step_avg:110.17ms
step:5400/20000 train_loss:2.0437 train_time:594307ms step_avg:110.06ms
step:5464/20000 val_loss:1.9995 val_bpb:1.1842 train_time:600015ms step_avg:109.81ms
stopping_early: wallclock_cap train_time:600015ms step:5464/20000
peak memory allocated: 21470 MiB reserved: 22170 MiB
swa:applying averaged 5 checkpoints
Serialized model: 87413062 bytes
Code size: 64426 bytes
Total submission size: 87477488 bytes
Serialized model int6+zstd: 15884217 bytes
Total submission size: 15948643 bytes
final_eval_mode:sliding_window stride:64
final_int8_zstd_roundtrip val_loss:1.9635 val_bpb:1.1629 eval_time:235968ms
final_int8_zstd_roundtrip_exact val_loss:1.96353693 val_bpb:1.16292025