Int6 mixed quantization with STE fake-int6 QAT, 3x MLP expansion, NorMuon optimizer, SWA checkpoint averaging, and sliding window eval.

## what changed

**MLP 3x expansion (hidden=1536)**: 21.8M params. The extra capacity is paid for in artifact bytes by int6 quantization.

**STE fake-int6 QAT**: weights are fake-quantized to int6 and dequantized via a straight-through estimator throughout training. The model learns weight distributions that survive 6-bit export, reducing the quantization penalty from ~0.008 to ~0.001 BPB.
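
A minimal sketch of the fake-quant forward under the per-row symmetric scheme described here (the clamp range [-31, 31] and the helper name are assumptions, not the script's exact code):

```python
import torch

def fake_quant_int6(w: torch.Tensor, qmax: int = 31) -> torch.Tensor:
    """Fake int6 quantization, symmetric per row, with a straight-through
    estimator: forward sees the quantized weight, backward sees identity."""
    scale = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) / qmax
    w_q = torch.clamp(torch.round(w / scale), -qmax, qmax) * scale
    return w + (w_q - w).detach()  # STE: gradient of the whole expression is 1
```

Each quantized `Linear` would call this on its weight in `forward`, so the optimizer steps full-precision weights while the loss sees the 6-bit grid.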

**NorMuon optimizer**: applies per-neuron (row-wise) RMS normalization to each update after Newton-Schulz orthogonalization, stabilizing updates across neurons with different activation scales.
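
A sketch of that normalization step; the Newton-Schulz coefficients follow the standard Muon recipe, and the row-wise RMS step mirrors the description above (the exact scaling in `train_gpt.py` may differ):

```python
import torch

def newton_schulz(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # quintic Newton-Schulz orthogonalization from the standard Muon recipe
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.bfloat16()
    if G.size(0) > G.size(1):
        X = X.T
    X = X / (X.norm() + 1e-7)
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if G.size(0) > G.size(1) else X

def normuon_update(momentum_buf: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    U = newton_schulz(momentum_buf).float()
    # row-wise RMS normalization: one row per output neuron
    rms = U.pow(2).mean(dim=1, keepdim=True).sqrt()
    return U / (rms + eps)
```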

**SWA checkpoint averaging**: collects model checkpoints every 200 steps during warmdown and uniformly averages them. Finds flatter minima than EMA.
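
A minimal sketch of the averaging itself, assuming checkpoints saved as CPU state dicts (paths and dtype handling are illustrative):

```python
import torch

def swa_average(checkpoint_paths):
    """Uniformly average model state_dicts collected during warmdown."""
    avg, n = None, 0
    for path in checkpoint_paths:
        sd = torch.load(path, map_location="cpu")
        if avg is None:
            avg = {k: v.float().clone() for k, v in sd.items()}
        else:
            for k, v in sd.items():
                avg[k] += v.float()
        n += 1
    return {k: v / n for k, v in avg.items()}
```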

**Mixed quantization**: int6 per-row on MLP and attention weights, fp16 passthrough for the tied embedding, zstd-22 compression.

**Sliding window eval (stride=64)**: each token is scored with nearly full left context.
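
A sketch of the idea, assuming `model(ids)` returns per-position logits (the script's interface may differ): each window advances by `stride` and only its last `stride` targets are newly scored, so every token after the first window sees at least seq_len - stride = 1984 tokens of context.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sliding_window_loss(model, ids, seq_len=2048, stride=64):
    """Mean NLL with each target scored once, using the longest
    available left context: windows overlap by seq_len - stride."""
    nll, count, prev_end = 0.0, 0, 0
    for begin in range(0, ids.size(0) - 1, stride):
        end = min(begin + seq_len, ids.size(0))
        window = ids[begin:end].unsqueeze(0)       # (1, L)
        logits = model(window[:, :-1])             # assumed: (1, L-1, V) logits
        losses = F.cross_entropy(logits.flatten(0, 1),
                                 window[:, 1:].flatten(), reduction="none")
        n_new = end - max(prev_end, begin + 1)     # targets not yet scored
        nll += losses[-n_new:].sum().item()
        count += n_new
        prev_end = end
        if end == ids.size(0):
            break
    return nll / count
```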

**seq_len=2048**, **batch=786K**, **grad_clip=0.3**, **matrix_lr=0.02**, **Muon momentum=0.99** (warmup from 0.92 over 1500 steps), **Muon weight decay=0.01**, **warmdown=3000 iters**, **logit softcap=15**.
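
For the two less-common knobs in that list, a sketch assuming the usual tanh-style softcap and a linear momentum warmup (both conventions are assumptions about `train_gpt.py`):

```python
import torch

def softcap(logits: torch.Tensor, cap: float = 15.0) -> torch.Tensor:
    # smoothly bounds logits to (-cap, cap) before the loss
    return cap * torch.tanh(logits / cap)

def muon_momentum(step: int, start: float = 0.92,
                  end: float = 0.99, warmup_steps: int = 1500) -> float:
    # linear warmup from `start` to `end`, then held constant
    t = min(step / warmup_steps, 1.0)
    return (1 - t) * start + t * end
```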

## config

```
VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4
MLP_MULT=3 TIE_EMBEDDINGS=1 TRAIN_SEQ_LEN=2048 TRAIN_BATCH_TOKENS=786432
EMBED_LR=0.03 MATRIX_LR=0.02 SCALAR_LR=0.02 TIED_EMBED_LR=0.03
MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 MUON_MOMENTUM_WARMUP_STEPS=1500
MUON_WEIGHT_DECAY=0.01 WARMDOWN_ITERS=3000 LOGIT_SOFTCAP=15
EVAL_STRIDE=64 GRAD_CLIP_NORM=0.3 ENABLE_QAT=1 EMA_DECAY=0.998
```

## run command

```bash
pip install zstandard
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

## results

8xH100 80GB HBM3 (Modal, 10 min wallclock, seed 1337):

| metric | val_loss | val_bpb | artifact |
|--------|----------|---------|----------|
| pre-quant (raw) | 2.0070 | 1.1887 | — |
| post-quant (standard) | 2.0055 | 1.1877 | 15.22 MB |
| **post-quant (sliding window, stride=64)** | **1.9697** | **1.1666** | 15.22 MB |

6,065 steps at 98.9ms/step. Sliding window eval: 156s (under 10 min eval budget).

## quantization details

- `.mlp.` and `.attn.` 2D weights: int6 per-row with STE QAT (54 tensors); export path sketched after this list
- `tok_emb.weight`: fp16 passthrough
- Small/control tensors: fp16/fp32 passthrough (38 tensors)
- Compression: zstd level 22
- Quant penalty: 0.001 BPB
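
A minimal sketch of that export path for one 2D tensor, assuming fp16 per-row scales and a four-values-into-three-bytes packing (layout and names are illustrative, not the actual serializer):

```python
import numpy as np
import zstandard as zstd

def pack_int6_per_row(w: np.ndarray, qmax: int = 31):
    """Per-row symmetric int6: fp16 scales + values packed 4-per-3-bytes."""
    scale = np.maximum(np.abs(w).max(axis=1, keepdims=True), 1e-8) / qmax
    q = np.clip(np.round(w / scale), -qmax, qmax).astype(np.int8)
    u = (q + 32).astype(np.uint8).ravel()      # shift [-31, 31] into [1, 63]
    u = np.pad(u, (0, (-u.size) % 4))          # pad to a multiple of 4 values
    a, b, c, d = u[0::4], u[1::4], u[2::4], u[3::4]
    out = np.empty(3 * a.size, dtype=np.uint8)
    out[0::3] = (a << 2) | (b >> 4)            # aaaaaabb
    out[1::3] = ((b & 0x0F) << 4) | (c >> 2)   # bbbbcccc
    out[2::3] = ((c & 0x03) << 6) | d          # ccdddddd
    return scale.astype(np.float16), out

scale, packed = pack_int6_per_row(np.random.randn(1536, 512).astype(np.float32))
blob = zstd.ZstdCompressor(level=22).compress(packed.tobytes())  # zstd-22 payload
```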

## files

- `train_gpt.py` — training script
- `train.log` — full 8xH100 run log (seed 1337)
- `submission.json`
## submission.json

```json
{
"author": "Abhishek Gahlot",
"github_id": "abhishekgahlot2",
"name": "Int6 Mixed Quant + MLP 3x + STE QAT + NorMuon + Sliding Window",
"blurb": "Mixed int6/int8 quantization with STE fake-int6 QAT, 3x MLP expansion (21.8M params), NorMuon optimizer, SWA checkpoint averaging, sliding window eval (stride=64), seq_len=2048, batch=786K, grad_clip=0.3, Muon momentum=0.99, zstd-22.",
"date": "2026-03-20",
"val_loss": 2.00546088,
"val_bpb": 1.18774689,
"sliding_window_val_loss": 1.96972383,
"sliding_window_val_bpb": 1.16658140,
"bytes_total": 15216221,
"bytes_code": 59581
}
```

## train.log

```
====================================================================================================
Running Python 3.12.10 (main, May 21 2025, 23:34:56) [GCC 12.2.0]
Running PyTorch 2.10.0+cu128
Thu Mar 19 22:21:32 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05 Driver Version: 580.95.05 CUDA Version: 13.0 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA H100 80GB HBM3 On | 00000000:04:00.0 Off | 0 |
| N/A 37C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 |
| N/A 33C P0 113W / 700W | 1521MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA H100 80GB HBM3 On | 00000000:0B:00.0 Off | 0 |
| N/A 38C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 3 NVIDIA H100 80GB HBM3 On | 00000000:0C:00.0 Off | 0 |
| N/A 34C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 |
| N/A 38C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 5 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | 0 |
| N/A 33C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 6 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 |
| N/A 35C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 7 NVIDIA H100 80GB HBM3 On | 00000000:8C:00.0 Off | 0 |
| N/A 33C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 1 C /bin/dumb-init 1512MiB |
| 1 N/A N/A 1 C /bin/dumb-init 1512MiB |
| 2 N/A N/A 1 C /bin/dumb-init 1512MiB |
| 3 N/A N/A 1 C /bin/dumb-init 1512MiB |
| 4 N/A N/A 1 C /bin/dumb-init 1512MiB |
| 5 N/A N/A 1 C /bin/dumb-init 1512MiB |
| 6 N/A N/A 1 C /bin/dumb-init 1512MiB |
| 7 N/A N/A 1 C /bin/dumb-init 1512MiB |
+-----------------------------------------------------------------------------------------+

====================================================================================================
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/cache/data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=/cache/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:21778504
world_size:8 grad_accum_steps:1
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02
train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9377 val_bpb:4.1089 train_time:0ms step_avg:0.02ms
step:1/20000 train_loss:6.9380 train_time:250ms step_avg:249.93ms
step:2/20000 train_loss:9.8318 train_time:319ms step_avg:159.53ms
step:3/20000 train_loss:7.9512 train_time:422ms step_avg:140.79ms
step:4/20000 train_loss:6.2761 train_time:520ms step_avg:129.99ms
step:5/20000 train_loss:5.6983 train_time:625ms step_avg:125.06ms
step:6/20000 train_loss:5.6061 train_time:719ms step_avg:119.81ms
step:7/20000 train_loss:5.5712 train_time:821ms step_avg:117.33ms
step:8/20000 train_loss:5.4609 train_time:923ms step_avg:115.43ms
step:9/20000 train_loss:5.3242 train_time:1025ms step_avg:113.86ms
step:10/20000 train_loss:5.1481 train_time:1121ms step_avg:112.07ms
step:200/20000 train_loss:2.3890 train_time:19663ms step_avg:98.31ms
step:400/20000 train_loss:2.4325 train_time:39397ms step_avg:98.49ms
step:600/20000 train_loss:2.3525 train_time:58918ms step_avg:98.20ms
step:800/20000 train_loss:2.2646 train_time:78638ms step_avg:98.30ms
step:1000/20000 train_loss:2.3016 train_time:98112ms step_avg:98.11ms
step:1000/20000 val_loss:2.2527 val_bpb:1.3342 train_time:98153ms step_avg:98.15ms
step:1200/20000 train_loss:2.3703 train_time:117744ms step_avg:98.12ms
step:1400/20000 train_loss:2.1822 train_time:137382ms step_avg:98.13ms
step:1600/20000 train_loss:2.0722 train_time:156876ms step_avg:98.05ms
step:1800/20000 train_loss:2.1593 train_time:176565ms step_avg:98.09ms
step:2000/20000 train_loss:2.0755 train_time:196055ms step_avg:98.03ms
step:2000/20000 val_loss:2.1357 val_bpb:1.2649 train_time:196092ms step_avg:98.05ms
step:2200/20000 train_loss:2.1498 train_time:215848ms step_avg:98.11ms
step:2400/20000 train_loss:2.0737 train_time:235326ms step_avg:98.05ms
step:2600/20000 train_loss:2.1130 train_time:255048ms step_avg:98.10ms
step:2800/20000 train_loss:2.1609 train_time:274792ms step_avg:98.14ms
step:3000/20000 train_loss:2.1654 train_time:294283ms step_avg:98.09ms
step:3000/20000 val_loss:2.0930 val_bpb:1.2396 train_time:294322ms step_avg:98.11ms
step:3200/20000 train_loss:2.1738 train_time:313954ms step_avg:98.11ms
step:3400/20000 train_loss:2.0245 train_time:333512ms step_avg:98.09ms
step:3600/20000 train_loss:2.0916 train_time:353295ms step_avg:98.14ms
step:3800/20000 train_loss:2.0723 train_time:372822ms step_avg:98.11ms
step:4000/20000 train_loss:1.9750 train_time:392571ms step_avg:98.14ms
step:4000/20000 val_loss:2.0617 val_bpb:1.2210 train_time:392603ms step_avg:98.15ms
step:4200/20000 train_loss:2.1511 train_time:412299ms step_avg:98.17ms
step:4400/20000 train_loss:2.0351 train_time:431845ms step_avg:98.15ms
step:4600/20000 train_loss:1.8441 train_time:452043ms step_avg:98.27ms
step:4800/20000 train_loss:2.4327 train_time:471776ms step_avg:98.29ms
step:5000/20000 train_loss:2.1045 train_time:492017ms step_avg:98.40ms
step:5000/20000 val_loss:2.0242 val_bpb:1.1988 train_time:492048ms step_avg:98.41ms
step:5200/20000 train_loss:2.0475 train_time:511935ms step_avg:98.45ms
step:5400/20000 train_loss:2.0575 train_time:532478ms step_avg:98.61ms
step:5600/20000 train_loss:1.9682 train_time:553584ms step_avg:98.85ms
step:5800/20000 train_loss:2.0199 train_time:573326ms step_avg:98.85ms
step:6000/20000 train_loss:1.9683 train_time:593616ms step_avg:98.94ms
step:6000/20000 val_loss:2.0070 val_bpb:1.1887 train_time:593649ms step_avg:98.94ms
step:6065/20000 val_loss:2.0070 val_bpb:1.1887 train_time:599931ms step_avg:98.92ms
stopping_early: wallclock_cap train_time:599931ms step:6065/20000
peak memory allocated: 16729 MiB reserved: 16920 MiB
ema:loading averaged weights (decay=0.998)
ema_val_loss:2.4269 ema_val_bpb:1.4374
ema:reverting to raw final weights (raw_val_bpb:1.1887 raw_val_loss:2.0070)
Serialized model: 86099351 bytes
Code size: 59581 bytes
Total submission size: 86158932 bytes
Serialized model mixed+zstd: 15156640 bytes (payload:22428960 raw_torch:22473115 payload_ratio:3.84x int6:54 int8:0 passthrough:38)
Total submission size mixed+zstd: 15216221 bytes
final_mixed_roundtrip val_loss:2.0055 val_bpb:1.1877 eval_time:1696ms
final_mixed_roundtrip_exact val_loss:2.00546088 val_bpb:1.18774689
final_sliding_window val_loss:1.9697 val_bpb:1.1666 stride:64 eval_time:155988ms
final_sliding_window_exact val_loss:1.96972383 val_bpb:1.16658140
```