@@ -0,0 +1,46 @@
First pilot run of Mixture of Softmaxes (MoS) on 1x H100 SXM, 10-minute wallclock.

Configuration:
- Track: `non-record`, 1x H100 SXM, 10 min wallclock
- Layout: `VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=2`
- MoS: `USE_MOS=1 MOS_K=2 MOS_RANK=64` (low-rank factorization, ~99K extra params; see the sketch after this list)
- Tied embeddings, seed=42
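
For orientation, a minimal sketch of the kind of low-rank MoS head this config implies. The exact factorization in `train_gpt.py` is not reproduced here; the layout below (shared `d→r` down-projection, one `r→K·d` up-projection, `d→K` gate) is an assumption, picked because it totals 512·64 + 64·(2·512) + 512·2 = 99,328 params, matching the reported ~99K (and ~97KB at one byte per param after int8):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LowRankMoSHead(nn.Module):
    """Mixture-of-Softmaxes head (Yang et al., 2018) with a low-rank context
    projection. Hypothetical layout; not the actual train_gpt.py module."""

    def __init__(self, dim: int = 512, k: int = 2, rank: int = 64):
        super().__init__()
        self.down = nn.Linear(dim, rank, bias=False)    # shared d -> r (32,768 params)
        self.up = nn.Linear(rank, k * dim, bias=False)  # r -> K*d     (65,536 params)
        self.gate = nn.Linear(dim, k, bias=False)       # d -> K        (1,024 params)
        self.k, self.dim = k, dim

    def forward(self, h: torch.Tensor, embed: torch.Tensor) -> torch.Tensor:
        # h: (B, T, d) hidden states; embed: (V, d) tied embedding matrix
        B, T, _ = h.shape
        ctx = torch.tanh(self.up(self.down(h)))         # (B, T, K*d)
        ctx = ctx.view(B, T, self.k, self.dim)
        logits = ctx @ embed.t()                        # (B, T, K, V)
        log_pi = F.log_softmax(self.gate(h), dim=-1)    # (B, T, K) mixture weights
        # Mix in probability space: log sum_k pi_k * softmax(logits_k)
        return torch.logsumexp(log_pi.unsqueeze(-1) + F.log_softmax(logits, dim=-1), dim=2)
```

The head returns log-probabilities, so the loss would be `F.nll_loss` on the flattened output. Mixing after the softmax (the `logsumexp`) is what breaks the single-softmax rank bottleneck; averaging logits instead would collapse back to one softmax.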

Command:
```bash
RUN_ID=mos_k2_r64_pilot \
DATA_PATH=./data/datasets/fineweb10B_sp1024 \
TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \
VOCAB_SIZE=1024 SEED=42 \
USE_MOS=1 MOS_K=2 MOS_RANK=64 \
MAX_WALLCLOCK_SECONDS=600 \
VAL_LOSS_EVERY=500 TRAIN_LOG_EVERY=100 \
torchrun --standalone --nproc_per_node=1 train_gpt.py
```
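
As a cross-check on the batch accounting (all numbers from the log further down; the micro-batch size is inferred, not logged):

```python
# Token budget per optimizer step, reconstructed from the training log.
grad_accum_steps = 8            # log: world_size:1 grad_accum_steps:8
train_seq_len = 1024            # log: train_seq_len:1024
train_batch_tokens = 524_288    # log: train_batch_tokens:524288
micro_batch_seqs = train_batch_tokens // (grad_accum_steps * train_seq_len)
assert micro_batch_seqs == 64   # 8 micro-batches of 64 sequences per step
```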

Key metrics:
- Stopped at step 1113/20000 (wallclock cap)
- Pre-quant: `val_loss:2.3505 val_bpb:1.3921`
- Post-quant (int8+zlib): `val_loss:2.3523 val_bpb:1.3932`
- Quantization degradation: +0.0011 bpb (minimal; see the serializer sketch after this list)
- Model params: 17,159,240
- Artifact: 12,764,492 bytes int8+zlib (12.8MB, 3.2MB under 16MB cap)
- Code: 63,345 bytes
- Total: 12,827,837 bytes
- Peak memory: 11,012 MiB allocated
- Step avg: 539ms/step on 1x H100
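
The quantized artifact appears to be per-tensor int8 plus a zlib pass: the log reports a 17,377,568-byte int8 payload compressing to 12,764,492 bytes. A hedged sketch of such a scheme, assuming symmetric per-tensor scales (the actual serializer in `train_gpt.py` is not shown here):

```python
import pickle
import zlib

import numpy as np
import torch

def serialize_int8_zlib(state_dict: dict) -> bytes:
    """Hypothetical int8+zlib path: symmetric per-tensor quantization to
    int8, then one zlib pass over the packed payload."""
    payload = {}
    for name, t in state_dict.items():
        w = t.detach().float().cpu()
        scale = float(w.abs().max().clamp_min(1e-12)) / 127.0  # per-tensor scale
        q = (w / scale).round().clamp(-127, 127).to(torch.int8)
        payload[name] = (tuple(w.shape), scale, q.numpy().tobytes())
    return zlib.compress(pickle.dumps(payload), level=9)

def deserialize_int8_zlib(blob: bytes) -> dict:
    """Roundtrip used for the post-quant eval: w_hat = q * scale."""
    out = {}
    for name, (shape, scale, raw) in pickle.loads(zlib.decompress(blob)).items():
        q = np.frombuffer(raw, dtype=np.int8).reshape(shape)
        out[name] = torch.from_numpy(q.astype(np.float32)) * scale
    return out
```

The +0.0011 bpb gap between the pre- and post-quant evals is the cost of that roundtrip.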

Training curve:
| Step | Train Loss | Val BPB | Time |
|------|-----------|---------|------|
| 0 | 6.93 | 4.11 | 0s |
| 100 | 3.27 | — | 54s |
| 500 | 2.58 | 1.52 | 271s |
| 1000 | 2.40 | 1.40 | 542s |
| 1113 | — | 1.39 | 600s |

Notes:
- Loss was still dropping at the wallclock stop; the model had more to learn
- No TTT/LoRA eval was run (only int8 roundtrip)
- No same-conditions baseline for direct comparison (8xH100 baseline: ~1.2244 bpb at 20K steps); a bpb sanity check follows these notes
- 1x H100 = ~1/8 throughput → only 1113 steps vs ~20K on 8xH100
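
A quick sanity check on the bpb figures: bits-per-byte is cross-entropy converted from nats to bits and normalized by bytes of raw text, so the reported pair implies the 1024-symbol SentencePiece tokenizer averages about 2.44 bytes per token on the validation set (an inference from the reported numbers, not a logged value):

```python
import math

val_loss, val_bpb = 2.3505, 1.3921          # pre-quant pair from the log
bits_per_token = val_loss / math.log(2)     # nats -> bits: ~3.39
bytes_per_token = bits_per_token / val_bpb  # implied tokenizer compression
print(f"~{bytes_per_token:.2f} bytes/token")  # ~2.44
```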
@@ -0,0 +1,31 @@
{
  "author": "billyendson",
  "github_id": "User123331",
  "name": "MoS K=2 Rank=64 Pilot (1xH100, 10min)",
  "blurb": "First pilot of Mixture of Softmaxes (K=2, low-rank=64) on 1xH100 SXM for 10 minutes. Tests softmax bottleneck breaking with minimal parameter overhead (~99K params, 97KB). Artifact 12.8MB, well under 16MB cap. No TTT eval. Loss still dropping at wallclock stop.",
  "date": "2026-03-21T19:48:40Z",
  "track": "non-record-unlimited-compute-16mb",
  "val_loss": 2.35234121,
  "val_bpb": 1.39318897,
  "pre_quant_val_loss": 2.3505,
  "pre_quant_val_bpb": 1.3921,
  "step_stop": 1113,
  "wallclock_seconds": 600.423,
  "bytes_total": 12827837,
  "bytes_model_int8_zlib": 12764492,
  "bytes_code": 63345,
  "gpu": "1xH100_SXM",
  "config": {
    "USE_MOS": 1,
    "MOS_K": 2,
    "MOS_RANK": 64,
    "VOCAB_SIZE": 1024,
    "NUM_LAYERS": 9,
    "MODEL_DIM": 512,
    "NUM_HEADS": 8,
    "NUM_KV_HEADS": 4,
    "MLP_MULT": 2,
    "SEED": 42,
    "TRAIN_SEQ_LEN": 1024
  }
}
@@ -0,0 +1,93 @@
logs/mos_k2_r64_pilot.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:17159240
world_size:1 grad_accum_steps:8
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04
train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:42
/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:7242: UserWarning:
Online softmax is disabled on the fly since Inductor decides to
split the reduction. Cut an issue to PyTorch if this is an
important use case and you want to speed it up with online
softmax.

warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:7242: UserWarning:
Online softmax is disabled on the fly since Inductor decides to
split the reduction. Cut an issue to PyTorch if this is an
important use case and you want to speed it up with online
softmax.

warnings.warn(
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:7242: UserWarning:
Online softmax is disabled on the fly since Inductor decides to
split the reduction. Cut an issue to PyTorch if this is an
important use case and you want to speed it up with online
softmax.

warnings.warn(
/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:7242: UserWarning:
Online softmax is disabled on the fly since Inductor decides to
split the reduction. Cut an issue to PyTorch if this is an
important use case and you want to speed it up with online
softmax.

warnings.warn(
step:0/20000 val_loss:6.9314 val_bpb:4.1052 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:6.9314 train_time:581ms step_avg:581.11ms
step:2/20000 train_loss:6.8515 train_time:1295ms step_avg:647.47ms
step:3/20000 train_loss:5.8655 train_time:1996ms step_avg:665.19ms
step:4/20000 train_loss:5.4250 train_time:2793ms step_avg:698.33ms
step:5/20000 train_loss:5.0728 train_time:3413ms step_avg:682.51ms
step:6/20000 train_loss:4.9797 train_time:4016ms step_avg:669.27ms
step:7/20000 train_loss:4.8555 train_time:4676ms step_avg:668.03ms
step:8/20000 train_loss:4.7612 train_time:5341ms step_avg:667.67ms
step:9/20000 train_loss:4.6900 train_time:5990ms step_avg:665.54ms
step:10/20000 train_loss:4.7029 train_time:6682ms step_avg:668.23ms
step:100/20000 train_loss:3.2746 train_time:54475ms step_avg:544.75ms
step:200/20000 train_loss:2.8511 train_time:108479ms step_avg:542.40ms
step:300/20000 train_loss:2.7046 train_time:162973ms step_avg:543.24ms
step:400/20000 train_loss:2.4804 train_time:217390ms step_avg:543.47ms
step:500/20000 train_loss:2.5755 train_time:271183ms step_avg:542.37ms
step:500/20000 val_loss:2.5703 val_bpb:1.5223 train_time:271193ms step_avg:542.39ms
step:600/20000 train_loss:2.5630 train_time:324786ms step_avg:541.31ms
step:700/20000 train_loss:2.5112 train_time:378359ms step_avg:540.51ms
step:800/20000 train_loss:2.3957 train_time:432963ms step_avg:541.20ms
step:900/20000 train_loss:2.4135 train_time:487589ms step_avg:541.77ms
step:1000/20000 train_loss:2.4031 train_time:542181ms step_avg:542.18ms
step:1000/20000 val_loss:2.3696 val_bpb:1.4034 train_time:542248ms step_avg:542.25ms
step:1100/20000 train_loss:2.3186 train_time:594115ms step_avg:540.10ms
step:1113/20000 val_loss:2.3505 val_bpb:1.3921 train_time:600423ms step_avg:539.46ms
stopping_early: wallclock_cap train_time:600423ms step:1113/20000
peak memory allocated: 11012 MiB reserved: 11320 MiB
Serialized model: 67623386 bytes
Code size: 63345 bytes
Total submission size: 67686731 bytes
Serialized model int8+zlib: 12764492 bytes (payload:17377568 raw_torch:17423635 payload_ratio:3.89x)
Total submission size int8+zlib: 12827837 bytes
final_int8_zlib_roundtrip val_loss:2.3523 val_bpb:1.3932 eval_time:11887ms
final_int8_zlib_roundtrip_exact val_loss:2.35234121 val_bpb:1.39318897