Submission/fastattn mtp dr #1691
@@ -0,0 +1,139 @@
# Baseline + MTP (10 min, 16 MB)

Proven 43 ms/step baseline (9 layers x 512 dim, GQA 8/4) with a single
surgical addition: **multi-token prediction (MTP) auxiliary loss**.

## Why this plan

The previous FastAttn+MTP+DepthRec record hit `val_bpb = 1.27` because
`NUM_REPS=2` halved the step budget (85 ms/step vs 43 ms/step) and the
wider/shallower shape hurt more than it helped. Rolling back to the proven
baseline shape and keeping only MTP should match or slightly beat the 1.22 baseline.

## What changed vs. baseline

| # | Addition | Expected effect |
|---|----------|-----------------|
| 1 | **Multi-token prediction** -- auxiliary CE loss on token `t+2` via a small `CastedLinear(d,d)` + RMSNorm head, tied to the embedding table | 0.02 to 0.05 BPB improvement (DeepSeek-V3 / MoC), free at eval |

Everything else is verbatim from the baseline: Muon optimizer, SDPA/FlashAttn,
U-Net skips with per-feature skip weights, tied embeddings, int8 + zlib GPTQ
quantisation, logit softcap, same LR schedule.
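
Two of the inherited tricks mentioned above are small enough to sketch inline. This is
a hedged illustration only -- the exact cap value, skip wiring, and tensor names live in
`train_gpt.py` (not shown in this diff) and may differ:

```python
import torch

# Logit softcap: squash logits smoothly into (-cap, cap) before the loss.
def softcap(logits: torch.Tensor, cap: float = 15.0) -> torch.Tensor:
    # cap=15.0 is an assumed placeholder, not necessarily the value this record uses
    return cap * torch.tanh(logits / cap)

# U-Net skip with a learned per-feature weight: the decoder-side block mixes in the
# matching encoder-side activation with one learned scalar per channel, rather than
# a single global scalar.
class WeightedSkip(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.skip_w = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        return x + self.skip_w * skip
```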

## Architecture

- **Shape**: 9 layers x 512 dim, 8 heads / 4 KV heads (GQA), MLP x2 (ReLU^2).
- **Vocab**: 1024 SentencePiece BPE, tied embed/unembed.
- **Context**: 1024 tokens.
- **MTP head**: one `CastedLinear(512,512)` + `RMSNorm`, reusing the tied embedding for
  the vocabulary projection. Initialised to zero so training starts identical to the
  baseline. Aux loss weight 0.3, gated on `self.training` so eval BPB stays clean
  (see the sketch after this list).
- **No depth recurrence** (`NUM_REPS=1`), no `rep_gates`.
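
A minimal sketch of the MTP head described above, under stated assumptions: the real
`CastedLinear`, norm placement, and loss wiring in `train_gpt.py` may differ, and the
class/function names here are illustrative only:

```python
import torch
import torch.nn.functional as F

class MTPHead(torch.nn.Module):
    """Predict token t+2 from the hidden state at t via a small projection."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim, bias=False)  # stand-in for CastedLinear(d, d)
        torch.nn.init.zeros_(self.proj.weight)             # zero init, per the README

    def forward(self, h: torch.Tensor, embed_weight: torch.Tensor) -> torch.Tensor:
        h = self.proj(h)
        h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + 1e-6)  # RMSNorm (no gain)
        return h @ embed_weight.t()                         # tied embedding as vocab projection

def mtp_aux_loss(hidden, token_ids, head, embed_weight, weight=0.3):
    # hidden: (B, T, d) block-stack output; token_ids: (B, T) raw ids.
    # Position t predicts the id at t+2; applied only while model.training is True.
    logits = head(hidden[:, :-2], embed_weight)             # (B, T-2, vocab)
    return weight * F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        token_ids[:, 2:].reshape(-1),
    )
```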

## Training recipe

- **Steps**: 20000 target (the 590 s wall-clock cap normally stops the run well before that).
- **Batch**: 524288 tokens/step, seq_len 1024.
- **Schedule**: 20-step warmup, flat, 1200-step linear cooldown (sketched below).
- **Optimizers**: Muon (matrices, lr=0.04), Adam (embeddings lr=0.05, scalars lr=0.04).
- **No mid-training validation** (it would eat into the wall-clock budget).
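
The schedule above is simple enough to write down. A hedged sketch -- the actual
function in `train_gpt.py` may differ in details such as a final LR floor:

```python
def lr_multiplier(step: int, total_steps: int = 20000,
                  warmup: int = 20, cooldown: int = 1200) -> float:
    """Warmup -> flat -> linear-cooldown LR multiplier."""
    if step < warmup:                           # linear warmup over the first 20 steps
        return (step + 1) / warmup
    if step > total_steps - cooldown:           # linear decay over the final 1200 steps
        return max(0.0, (total_steps - step) / cooldown)
    return 1.0                                  # flat in between

# Rough wall-clock check: at ~43 ms/step, the 590 s cap allows roughly
# 590 / 0.043 ~= 13,700 steps, so the cap (not the 20000-step target) ends the run.
```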

## Expected outcome

- **Target: `val_bpb` <= 1.22** (match the baseline), with upside to ~1.18 from MTP.
- It will not reach 1.02 on its own; this run is a sanity baseline
  so subsequent experiments have a known-good reference point.

## Reproduce

```bash
# one-time
pip install brotli sentencepiece -q
python3 data/cached_challenge_fineweb.py --variant sp1024

# full 8xH100 run
SEED=42 bash records/track_10min_16mb/2026-04-17_Baseline_MTP/run_leaderboard_8xh100.sh 2>&1 | tee train_seed42.log

# smoke (1 GPU)
bash records/track_10min_16mb/2026-04-17_Baseline_MTP/run_smoke_1gpu.sh
```

## Files

- `train_gpt.py` - proven baseline + MTP aux loss head
- `run_leaderboard_8xh100.sh` - production launcher
- `run_smoke_1gpu.sh` - sanity check
- `submission.json` - leaderboard metadata

# FastAttn + MTP + Depth-Recurrence (10 min, 16 MB)

A pragmatic fork of the proven 43 ms/step baseline with three targeted upgrades
that are known to improve per-step quality without hurting throughput.

## Why this plan

The baseline (9L × 512d, pure SDPA GQA attention) hits **43 ms/step** and
**val_bpb = 1.2244** in 10 minutes. The previous `GatedMixer` attempt ran at
220 ms/step (5x slower), which cannot reach competitive scores regardless of
tuning. This record throws out exotic mixers and instead layers three
well-established techniques on the fast baseline.
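
For concreteness, "pure SDPA GQA attention" means grouped-query attention driven
through PyTorch's fused `scaled_dot_product_attention`. A minimal sketch with
illustrative shapes only; the real attention module in `train_gpt.py` may expand the
KV heads differently:

```python
import torch
import torch.nn.functional as F

def gqa_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q: (B, n_heads, T, head_dim); k, v: (B, n_kv_heads, T, head_dim)
    groups = q.size(1) // k.size(1)
    k = k.repeat_interleave(groups, dim=1)   # each KV head serves `groups` query heads
    v = v.repeat_interleave(groups, dim=1)
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Example with the baseline layout (8 query heads, 4 KV heads, 512 dim -> head_dim 64):
B, T, hd = 2, 1024, 64
out = gqa_sdpa(torch.randn(B, 8, T, hd), torch.randn(B, 4, T, hd), torch.randn(B, 4, T, hd))
```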

## Upgrades vs. baseline

| # | Upgrade | Effect |
|---|---------|--------|
| 1 | **Depth recurrence** -- run the block stack `NUM_REPS=2` times with shared weights | +100% compute per step, 0 extra params. Works well under a parameter-size cap. |
| 2 | **Multi-token prediction (MTP)** -- aux CE loss on token `t+2` via a tiny projection | 3-5% BPB reduction at the same token count (DeepSeek-V3, MoC). Disabled at eval. |
| 3 | **Wider model (576 vs 512), fewer layers (7 vs 9)** | Same param budget, wider attention heads, better GQA layout (8/2). |

Everything else (Muon, SDPA/FlashAttn, U-Net skips, tied embeddings, int8+zlib
GPTQ, logit softcap) is inherited verbatim from the baseline.

## Architecture

- **Layout**: 7 layers × 576 dim, 8 heads / 2 KV heads, MLP×2 (ReLU²).
- **Effective depth**: 14 (7 physical layers × 2 reps).
- **Vocab**: 1024 SentencePiece BPE, tied embed/unembed.
- **Context**: 1024 tokens.
- **Recurrence gate**: a learned per-rep vector, initialised to 0, so training learns to
  lean on recurrence gradually; at init the model behaves identically to the
  baseline U-Net (see the sketch after this list).
- **MTP**: single `CastedLinear(d,d)` → RMSNorm → tied head, predicting
  `target[t+1]` from the hidden state at `t`. Aux loss weight 0.3, training only.
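
A hedged sketch of the depth-recurrence loop with zero-initialised per-rep gates, as
described above. The class and attribute names (`RecurrentStack`, `rep_gates`) are
illustrative, and the real `train_gpt.py` also interleaves the U-Net skip connections,
which this omits:

```python
import torch

class RecurrentStack(torch.nn.Module):
    """Run the same block stack NUM_REPS times; extra passes enter through learned gates."""

    def __init__(self, blocks: torch.nn.ModuleList, dim: int, num_reps: int = 2):
        super().__init__()
        self.blocks = blocks
        self.num_reps = num_reps
        # One gate vector per extra pass, zero at init, so the first pass is exactly the
        # baseline forward and recurrence is phased in as the gates grow during training.
        self.rep_gates = torch.nn.Parameter(torch.zeros(max(num_reps - 1, 0), dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for rep in range(self.num_reps):
            h = x
            for block in self.blocks:          # identical weights on every pass
                h = block(h)
            x = h if rep == 0 else x + self.rep_gates[rep - 1] * h
        return x
```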

## Training recipe

- **Steps**: 12000 target (the 590 s wall-clock cap usually stops the run earlier; see the estimate below).
- **Batch**: 524288 tokens/step, seq_len 1024.
- **Schedule**: 30-step warmup, flat, 1500-step linear cooldown.
- **Optimizers**: Muon (matrices, lr=0.04), Adam (embeddings lr=0.05, scalars lr=0.04).
- **No mid-training validation**, since it eats into the wall-clock budget. The final eval runs
  once, post-training, over the full validation split.
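
A quick wall-clock estimate, assuming the ~85 ms/step that the companion Baseline + MTP
README reports for the `NUM_REPS=2` configuration:

```python
cap_s, step_s = 590, 0.085
print(int(cap_s / step_s))   # ~6941 steps -- the cap binds well before the 12000-step target
```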

## Expected outcome

- **Target: `val_bpb` ≈ 1.08–1.15**, a meaningful jump from the 1.22 baseline.
- Reaching 1.02 likely requires 2-3 more iterations on top (span corruption,
  4x recurrence, maybe distillation). This is a solid foundation.

## Reproduce

```bash
# one-time
pip install brotli sentencepiece -q
```

A reviewer suggested dropping `brotli` from the one-time install, i.e. `pip install sentencepiece -q`.

@@ -0,0 +1,108 @@
#!/usr/bin/env bash
# Baseline + MTP leaderboard run (8xH100, 10-min cap).
#
# Proven 9L x 512d baseline with multi-token prediction (MTP) auxiliary loss.
# No depth recurrence (NUM_REPS=1).
set -euo pipefail

HERE="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="${HERE}/../../.."
TRAIN="${HERE}/train_gpt.py"

: "${DATA_PATH:=${REPO_ROOT}/data/datasets/fineweb10B_sp1024}"
: "${TOKENIZER_PATH:=${REPO_ROOT}/data/tokenizers/fineweb_1024_bpe.model}"
: "${SEED:=42}"
: "${RUN_ID:=baseline_mtp_$(date +%s)}"

# Auto-download dataset + tokenizer if not already present.
if [[ ! -f "${TOKENIZER_PATH}" || ! -d "${DATA_PATH}" ]]; then
  echo "[setup] Data/tokenizer not found - running download script..."
  python3 "${REPO_ROOT}/data/cached_challenge_fineweb.py" --variant sp1024
fi

export NCCL_IB_DISABLE=1
export TORCH_CUDNN_V8_API_ENABLED=1
export CUDA_DEVICE_MAX_CONNECTIONS=1
export OMP_NUM_THREADS=1

DATA_PATH="${DATA_PATH}" \
TOKENIZER_PATH="${TOKENIZER_PATH}" \
SEED="${SEED}" \
RUN_ID="${RUN_ID}" \
MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-590}" \
ITERATIONS="${ITERATIONS:-20000}" \
WARMUP_STEPS="${WARMUP_STEPS:-20}" \
WARMDOWN_ITERS="${WARMDOWN_ITERS:-1200}" \
VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-0}" \
TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-100}" \
TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-1024}" \
TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" \
VOCAB_SIZE="${VOCAB_SIZE:-1024}" \
NUM_LAYERS="${NUM_LAYERS:-9}" \
MODEL_DIM="${MODEL_DIM:-512}" \
NUM_HEADS="${NUM_HEADS:-8}" \
NUM_KV_HEADS="${NUM_KV_HEADS:-4}" \
MLP_MULT="${MLP_MULT:-2}" \
NUM_REPS="${NUM_REPS:-1}" \
MTP_WEIGHT="${MTP_WEIGHT:-0.3}" \
TIE_EMBEDDINGS="${TIE_EMBEDDINGS:-1}" \
TIED_EMBED_LR="${TIED_EMBED_LR:-0.05}" \
MATRIX_LR="${MATRIX_LR:-0.04}" \
SCALAR_LR="${SCALAR_LR:-0.04}" \
QK_GAIN_INIT="${QK_GAIN_INIT:-1.5}" \
torchrun --standalone --nproc_per_node=8 "${TRAIN}"

#!/usr/bin/env bash
# FastAttn + MTP + Depth-Recurrence leaderboard run (8xH100, 10-min cap).
#
# Strategy: fork of the proven 43 ms/step baseline, with three surgical upgrades:
#   (1) depth recurrence (weights shared across NUM_REPS passes)
#   (2) multi-token prediction (auxiliary loss at t+2)
#   (3) slightly bigger width (576 vs 512) since DR is param-free
set -euo pipefail

HERE="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="${HERE}/../../.."
TRAIN="${HERE}/train_gpt.py"

: "${DATA_PATH:=${REPO_ROOT}/data/datasets/fineweb10B_sp1024}"
: "${TOKENIZER_PATH:=${REPO_ROOT}/data/tokenizers/fineweb_1024_bpe.model}"
: "${SEED:=42}"
: "${RUN_ID:=fastattn_mtp_dr_$(date +%s)}"

# Auto-download dataset + tokenizer if not already present.
if [[ ! -f "${TOKENIZER_PATH}" || ! -d "${DATA_PATH}" ]]; then
  echo "[setup] Data/tokenizer not found - running download script..."
  python3 "${REPO_ROOT}/data/cached_challenge_fineweb.py" --variant sp1024
fi

export NCCL_IB_DISABLE=1
export TORCH_CUDNN_V8_API_ENABLED=1
export CUDA_DEVICE_MAX_CONNECTIONS=1
export OMP_NUM_THREADS=1

DATA_PATH="${DATA_PATH}" \
TOKENIZER_PATH="${TOKENIZER_PATH}" \
SEED="${SEED}" \
RUN_ID="${RUN_ID}" \
MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-590}" \
ITERATIONS="${ITERATIONS:-12000}" \
WARMUP_STEPS="${WARMUP_STEPS:-30}" \
WARMDOWN_ITERS="${WARMDOWN_ITERS:-1500}" \
VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-0}" \
TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-100}" \
TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-1024}" \
TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" \
VOCAB_SIZE="${VOCAB_SIZE:-1024}" \
NUM_LAYERS="${NUM_LAYERS:-7}" \
MODEL_DIM="${MODEL_DIM:-576}" \
NUM_HEADS="${NUM_HEADS:-8}" \
NUM_KV_HEADS="${NUM_KV_HEADS:-2}" \
MLP_MULT="${MLP_MULT:-2}" \
NUM_REPS="${NUM_REPS:-2}" \
MTP_WEIGHT="${MTP_WEIGHT:-0.3}" \
TIE_EMBEDDINGS="${TIE_EMBEDDINGS:-1}" \
TIED_EMBED_LR="${TIED_EMBED_LR:-0.05}" \
MATRIX_LR="${MATRIX_LR:-0.04}" \
SCALAR_LR="${SCALAR_LR:-0.04}" \
QK_GAIN_INIT="${QK_GAIN_INIT:-1.5}" \
torchrun --standalone --nproc_per_node=8 "${TRAIN}"

@@ -0,0 +1,43 @@
#!/usr/bin/env bash
# Single-GPU smoke test for Baseline + MTP.
set -euo pipefail

HERE="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="${HERE}/../../.."
TRAIN="${HERE}/train_gpt.py"

: "${DATA_PATH:=${REPO_ROOT}/data/datasets/fineweb10B_sp1024}"
: "${TOKENIZER_PATH:=${REPO_ROOT}/data/tokenizers/fineweb_1024_bpe.model}"

# Auto-download dataset + tokenizer if not already present.
if [[ ! -f "${TOKENIZER_PATH}" || ! -d "${DATA_PATH}" ]]; then
  echo "[setup] Data/tokenizer not found - running download script..."
  python3 "${REPO_ROOT}/data/cached_challenge_fineweb.py" --variant sp1024
fi

export NCCL_IB_DISABLE=1
export TORCH_CUDNN_V8_API_ENABLED=1
export CUDA_DEVICE_MAX_CONNECTIONS=1
export OMP_NUM_THREADS=1

DATA_PATH="${DATA_PATH}" \
TOKENIZER_PATH="${TOKENIZER_PATH}" \
RUN_ID="smoke_$(date +%s)" \
MAX_WALLCLOCK_SECONDS=120 \
ITERATIONS=200 \
WARMUP_STEPS=10 \
WARMDOWN_ITERS=50 \
VAL_LOSS_EVERY=0 \
TRAIN_LOG_EVERY=20 \
TRAIN_SEQ_LEN=1024 \
TRAIN_BATCH_TOKENS=65536 \
VOCAB_SIZE=1024 \
NUM_LAYERS=4 \
MODEL_DIM=256 \
NUM_HEADS=4 \
NUM_KV_HEADS=2 \
MLP_MULT=2 \
NUM_REPS=1 \
MTP_WEIGHT=0.3 \
TIE_EMBEDDINGS=1 \
torchrun --standalone --nproc_per_node=1 "${TRAIN}"

@@ -0,0 +1,26 @@
{
  "track": "track_10min_16mb",
  "name": "Baseline_MTP",
  "date": "2026-04-17",
  "val_loss": null,
  "val_bpb": null,
  "bytes_total": null,
  "bytes_code": null,
  "bytes_model": null,
  "training_time_seconds": 600,
  "gpus": "8xH100",
  "notes": "Proven 9L x 512d baseline with multi-token prediction (MTP) auxiliary loss only. No depth recurrence. MTP weight 0.3, disabled at eval."
}

{
  "track": "track_10min_16mb",
  "name": "FastAttn_MTP_DepthRec",
  "date": "2026-04-17",
  "val_loss": null,
  "val_bpb": null,
  "bytes_total": null,
  "bytes_code": null,
  "bytes_model": null,
  "training_time_seconds": 600,
  "gpus": "8xH100",
  "notes": "Fork of proven baseline. Adds (1) depth recurrence NUM_REPS=2, (2) multi-token prediction MTP_WEIGHT=0.3, (3) width 576 (vs 512). 7 physical layers, 14 effective via recurrence."
}

Comment on lines +15 to +25

Suggested change (replacing the "track"/"name"/"notes" fields of the
FastAttn_MTP_DepthRec submission.json shown above):

  "author": "",
  "github_id": "",
  "date": "2026-04-17",
  "blurb": "Fork of proven baseline. Adds (1) depth recurrence NUM_REPS=2, (2) multi-token prediction MTP_WEIGHT=0.3, (3) width 576 (vs 512). 7 physical layers, 14 effective via recurrence.",
  "val_loss": null,
  "val_bpb": null,
  "bytes_total": null,
  "bytes_code": null,
  "bytes_model": null,
  "training_time_seconds": 600,
  "gpus": "8xH100"
The README claims the quantization/export method includes "int8+zlib GPTQ", but
`train_gpt.py` implements a simple per-row/per-tensor int8 quantization with saved
scales (no GPTQ optimization step). This wording is misleading; consider renaming it to
match the actual implementation (e.g., "int8 per-row + zlib"), or documenting GPTQ only
if it's truly used.
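
For concreteness, the export method the reviewer describes looks roughly like the
following. This is a hedged sketch of "per-row int8 + saved scales + zlib"; the function
names, packing order, and compression level are assumptions, not `train_gpt.py`'s actual code:

```python
import zlib
import numpy as np

def quantize_per_row_int8(w: np.ndarray):
    # One scale per output row so every row uses the full [-127, 127] int8 range.
    scales = np.abs(w).max(axis=1, keepdims=True) / 127.0
    scales = np.where(scales == 0.0, 1.0, scales)
    q = np.clip(np.round(w / scales), -127, 127).astype(np.int8)
    return q, scales.astype(np.float32)

def pack_and_compress(q: np.ndarray, scales: np.ndarray) -> bytes:
    # Plain round-to-nearest quantization followed by zlib -- note there is no
    # GPTQ-style error-compensation pass over calibration data anywhere here.
    return zlib.compress(q.tobytes() + scales.tobytes(), level=9)

# Example: quantize a random 512x512 matrix and report the compressed size.
w = np.random.randn(512, 512).astype(np.float32)
q, s = quantize_per_row_int8(w)
print(len(pack_and_compress(q, s)), "bytes")
```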