# QAT Int5/Int6 + Backout + U-Net Skips + BigramHash(10240) + SWA50

**val_bpb: 1.1477** (seed=42, sliding-window eval with stride=64, measured after the int5/int6 + zstd quantization round-trip; 15.94 MB)

## Run Command

```bash
pip install zstandard  # optional but recommended for better compression
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

All parameters are set as defaults in `train_gpt.py`; no environment variables are needed.
The script falls back to zlib if zstandard is not installed.

## Key Techniques

### Quantization-Aware Training (QAT)
During training, large weight matrices pass through a simulated quantization bottleneck using the
straight-through estimator (STE). MLP weights see int5 noise, attention weights see int6 noise.
The model learns to be robust to quantization, cutting the post-quant penalty from ~0.016 BPB
to ~0.005 BPB, roughly 0.01 BPB gained over post-training quantization alone.
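
A minimal sketch of STE fake quantization; the function name and per-tensor scale choice are illustrative, not necessarily what `train_gpt.py` does:

```python
import torch

def fake_quantize_ste(w: torch.Tensor, bits: int) -> torch.Tensor:
    """Simulated symmetric integer quantization with a straight-through estimator.

    The forward pass sees the rounded (quantized) weights; the backward pass
    treats the round-trip as identity, so gradients flow through unchanged.
    """
    qmax = 2 ** (bits - 1) - 1                       # 15 for int5, 31 for int6
    scale = w.abs().max().clamp(min=1e-8) / qmax     # per-tensor symmetric scale
    w_q = (w / scale).round().clamp(-qmax - 1, qmax) * scale
    return w + (w_q - w).detach()                    # STE: forward = w_q, gradient = identity

# MLP weights would use fake_quantize_ste(w, bits=5); attention weights bits=6.
```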

### Backout (Learned Residual Subtraction)
A learned scalar λ (init=0.2) subtracts the midpoint layer's activation from the final output:
`x = x - λ * x_mid`. This prevents over-reliance on early representations.
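
A minimal sketch, assuming λ is a plain scalar parameter (module name is illustrative):

```python
import torch
import torch.nn as nn

class Backout(nn.Module):
    """Learned residual subtraction: x_out = x_final - lambda * x_mid."""
    def __init__(self, init: float = 0.2):
        super().__init__()
        self.lam = nn.Parameter(torch.tensor(init))   # learned scalar, init=0.2

    def forward(self, x_final: torch.Tensor, x_mid: torch.Tensor) -> torch.Tensor:
        return x_final - self.lam * x_mid
```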

### U-Net Skip Connections
Encoder-decoder structure (5+5 layers) with learned per-dimension skip weights connecting
encoder to decoder layers.
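
One plausible form of the learned skips, sketched under the assumption that decoder layer i adds the matching encoder activation scaled by a per-dimension gate (names and zero init are illustrative):

```python
import torch
import torch.nn as nn

class UNetSkips(nn.Module):
    """Per-dimension learned gates for encoder-to-decoder skip connections."""
    def __init__(self, n_skips: int = 5, dim: int = 512):
        super().__init__()
        self.gates = nn.Parameter(torch.zeros(n_skips, dim))  # start with no skip contribution

    def forward(self, x_dec: torch.Tensor, x_enc: torch.Tensor, i: int) -> torch.Tensor:
        # Decoder layer i receives the stored encoder activation, gated per dimension.
        return x_dec + self.gates[i] * x_enc
```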

### SVD Embedding Initialization
Tied embeddings initialized via SVD spectral decay: singular values reshaped to follow a
1/√k profile for better initial token representations.
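
A minimal sketch of the spectral reshaping, assuming the decay is anchored at the largest singular value (the exact normalization in `train_gpt.py` may differ):

```python
import torch

@torch.no_grad()
def svd_spectral_init(embed: torch.Tensor) -> torch.Tensor:
    """Reshape the singular-value spectrum of a (vocab, dim) embedding to ~1/sqrt(k)."""
    U, S, Vh = torch.linalg.svd(embed, full_matrices=False)
    k = torch.arange(1, S.numel() + 1, dtype=embed.dtype, device=embed.device)
    new_S = S.max() / k.sqrt()     # 1/sqrt(k) decay, anchored at the top singular value
    return (U * new_S) @ Vh        # reassemble the embedding with the reshaped spectrum
```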

### Mixed Int5/Int6 Quantization + zstd-22
- Int5 [-16,15] for MLP weights (most compressible)
- Int6 [-32,31] for attention weights (precision-sensitive)
- FP16 for tied embeddings and last-2-layer key projections (Late-K)
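
A minimal sketch of the quantize-then-compress path with per-tensor symmetric scales; the actual checkpoint format (e.g. bit-packing of int5/int6 values) is not shown:

```python
import torch

try:
    import zstandard as zstd
    def compress(buf: bytes) -> bytes:
        return zstd.ZstdCompressor(level=22).compress(buf)   # zstd at max level 22
except ImportError:
    import zlib
    def compress(buf: bytes) -> bytes:                        # weaker fallback
        return zlib.compress(buf, 9)

def quantize(w: torch.Tensor, bits: int):
    """Symmetric per-tensor quantization to signed `bits`-bit integers (stored in int8)."""
    qmax = 2 ** (bits - 1) - 1
    scale = w.abs().max().item() / qmax
    q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax).to(torch.int8)
    return q.cpu().numpy(), scale

q, scale = quantize(torch.randn(1536, 512), bits=5)   # e.g. an MLP weight
blob = compress(q.tobytes())
```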

### BigramHash(10240) + SmearGate
Consecutive token pairs are hashed into a 10240-bucket embedding table (dim=128, projected to 512).
SmearGate blends each token with the previous token's embedding.
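
A minimal sketch of both pieces; the hash constant, gate parameterization, and blend form are assumptions, not necessarily what PR #162 uses:

```python
import torch
import torch.nn as nn

class BigramHash(nn.Module):
    """Hash (prev_token, token) pairs into a small bucketed embedding table."""
    def __init__(self, n_buckets: int = 10240, dim: int = 128, d_model: int = 512):
        super().__init__()
        self.n_buckets = n_buckets
        self.table = nn.Embedding(n_buckets, dim)
        self.proj = nn.Linear(dim, d_model, bias=False)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        prev = torch.roll(ids, shifts=1, dims=-1)
        prev[..., 0] = 0                                   # first position has no predecessor
        h = (prev * 1000003 + ids) % self.n_buckets        # simple multiplicative hash
        return self.proj(self.table(h))

class SmearGate(nn.Module):
    """Blend each position with the previous position's embedding via a learned gate."""
    def __init__(self, d_model: int = 512):
        super().__init__()
        self.gate = nn.Linear(d_model, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:   # x: (batch, seq, d_model)
        prev = torch.roll(x, shifts=1, dims=1)
        prev[:, 0] = 0.0
        g = torch.sigmoid(self.gate(x))                    # per-position blend weight in (0, 1)
        return (1 - g) * x + g * prev
```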

### Stochastic Weight Averaging
SWA is applied every 50 steps during warmdown (start_frac=0.4). Smoother weight distributions quantize better.
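
A minimal sketch of the running average, assuming equal weighting of snapshots (`torch.optim.swa_utils.AveragedModel` would work equally well; the helper name is illustrative):

```python
import torch

@torch.no_grad()
def swa_update(swa_params, model_params, n_averaged: int) -> int:
    """Fold the current weights into the running SWA average: avg <- (avg * n + w) / (n + 1)."""
    for p_swa, p in zip(swa_params, model_params):
        p_swa.mul_(n_averaged / (n_averaged + 1)).add_(p, alpha=1.0 / (n_averaged + 1))
    return n_averaged + 1

# Called every swa_every=50 steps once training passes the start fraction (0.4).
```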

## Architecture
- 10 layers, 512 dim, 8 heads, 4 KV heads (GQA)
- MLP 3x expansion (hidden=1536), relu² activation
- Tied embeddings, logit softcap=30
- Orthogonal init with muP-scaled output projections
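
The same architecture expressed as a config sketch (field names are illustrative, not necessarily those in `train_gpt.py`):

```python
from dataclasses import dataclass

@dataclass
class ModelConfig:
    n_layer: int = 10           # 5 encoder + 5 decoder layers with U-Net skips
    n_embd: int = 512
    n_head: int = 8
    n_kv_head: int = 4          # grouped-query attention
    mlp_hidden: int = 1536      # 3x expansion, relu^2 activation
    logit_softcap: float = 30.0
    tie_embeddings: bool = True
```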

## Hyperparameters

| Parameter | Value |
|-----------|-------|
| matrix_lr (Muon) | 0.02 |
| scalar_lr (AdamW) | 0.02 |
| tied_embed_lr | 0.03 |
| muon_weight_decay | 0.04 |
| adamw_weight_decay | 0.01 |
| warmdown_iters | 3000 |
| swa_every / start_frac | 50 / 0.4 |
| prune_frac | 0.08 |
| eval_stride | 64 |
| compressor | zstd-22 |

Built on PR #162 by @unnir (SmearGate, BigramHash, OrthoInit) and techniques from @thwu1 and @raahilshah.