# FastAttn + MTP + Depth-Recurrence (10 min, 16 MB)

A pragmatic fork of the proven 43 ms/step baseline, with three targeted upgrades
known to improve per-step quality without hurting throughput.

## Why this plan

The baseline (9L × 512d, pure SDPA GQA attention) hits **43 ms/step** and
**val_bpb = 1.2244** in 10 minutes. The previous `GatedMixer` attempt ran at
220 ms/step, roughly 5x slower, and cannot reach competitive scores regardless
of tuning. This record throws out the exotic mixer and instead layers three
well-established techniques on the fast baseline.

## Upgrades vs. baseline

| # | Upgrade | Effect |
|---|---------|--------|
| 1 | **Depth recurrence** — run block stack `NUM_REPS=2` times, weights shared | +100% compute per step, 0 extra params. Works well under a parameter-size cap. |
| 2 | **Multi-token prediction (MTP)** — aux CE loss on token `t+2` via a tiny projection | 3-5% BPB reduction at the same token budget (DeepSeek-V3, MoC). Disabled at eval. |
| 3 | **Wider model (576 vs 512), fewer layers (7 vs 9)** | Same param budget (sanity-checked below), a wider residual stream, and a clean 8-query/2-KV GQA layout. |

Everything else (Muon, SDPA/FlashAttn, U-Net skips, tied embeddings, int8+zlib
GPTQ, logit softcap) is inherited verbatim from the baseline.
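
The "same param budget" claim in row 3 is easy to sanity-check. The sketch
below is a back-of-envelope estimate only: it assumes bias-free linears,
`head_dim = dim // heads`, a `dim → 2·dim → dim` MLP, and that the baseline
also uses the 8/2 head layout (its head count is not stated above); norms,
gates, U-Net skip weights, and the tied embedding are ignored.

```python
# Back-of-envelope parameter count for one GQA + MLP block.
# ASSUMPTIONS (not confirmed by this record): bias-free linears,
# head_dim = dim // heads, MLP is dim -> mlp_mult*dim -> dim.
def block_params(dim: int, heads: int, kv_heads: int, mlp_mult: int = 2) -> int:
    head_dim = dim // heads
    attn = dim * dim                          # W_q: all query heads
    attn += 2 * dim * (kv_heads * head_dim)   # W_k + W_v: only kv_heads each
    attn += dim * dim                         # W_o: output projection
    mlp = 2 * dim * (mlp_mult * dim)          # up + down projections
    return attn + mlp

baseline = 9 * block_params(512, heads=8, kv_heads=2)  # baseline 8/2 assumed
record   = 7 * block_params(576, heads=8, kv_heads=2)
print(f"baseline ~{baseline/1e6:.1f}M, record ~{record/1e6:.1f}M")
# -> baseline ~15.3M, record ~15.1M: essentially the same block-weight budget
```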

## Architecture

- **Layout**: 7 layers × 576 dim, 8 heads / 2 KV heads, MLP×2 (ReLU²).
- **Effective depth**: 14 (7 physical × 2 reps).
- **Vocab**: 1024 SentencePiece BPE, tied embed/unembed.
- **Context**: 1024 tokens.
- **Recurrence gate**: per-rep learned vector, initialized to 0, so training
  learns to lean on recurrence gradually; at init the model behaves identically
  to the baseline U-Net (see the sketch after this list).
- **MTP**: single `CastedLinear(d,d)` → RMSNorm → tied head, predicting
  `target[t+1]` (i.e. the token at position `t+2`) from the hidden state at `t`.
  Aux loss weight 0.3, training only.
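
A minimal PyTorch sketch of what these two training-time additions could look
like, assuming a standard block stack. The names (`RecurrentStack`, `rep_gate`,
`MTPHead`) are illustrative, not the identifiers in `train_gpt.py`, and
`nn.Linear` stands in for the baseline's `CastedLinear`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RecurrentStack(nn.Module):
    """Runs a shared block stack NUM_REPS times with per-rep gate vectors."""
    def __init__(self, blocks: nn.ModuleList, dim: int, num_reps: int = 2):
        super().__init__()
        self.blocks = blocks
        self.num_reps = num_reps
        # One gate vector per extra rep, initialized to zero: at init the
        # extra passes contribute nothing, so the model matches the baseline.
        self.rep_gate = nn.Parameter(torch.zeros(num_reps - 1, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:             # first pass: plain baseline
            x = block(x)
        for rep in range(self.num_reps - 1):  # extra passes, weights shared
            h = x
            for block in self.blocks:
                h = block(h)
            x = x + self.rep_gate[rep] * (h - x)  # zero gate -> no-op at init
        return x

class MTPHead(nn.Module):
    """Aux head: projection -> RMSNorm -> tied unembedding. Training only."""
    def __init__(self, dim: int, tied_head: nn.Linear):
        super().__init__()
        self.proj = nn.Linear(dim, dim, bias=False)
        self.norm = nn.RMSNorm(dim)
        self.head = tied_head  # shared with the main next-token head

    def aux_loss(self, hidden: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # hidden[:, t] predicts targets[:, t+1], i.e. the input token at t+2.
        logits = self.head(self.norm(self.proj(hidden[:, :-1])))
        return F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), targets[:, 1:].reshape(-1)
        )

# training loss = main next-token CE + 0.3 * MTPHead.aux_loss(...)
```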

## Training recipe

- **Steps**: 12000 target (the 590 s wall-clock cap usually stops training earlier).
- **Batch**: 524288 tokens/step, seq_len 1024.
- **Schedule**: 30-step warmup, flat, 1500-step linear cooldown (sketched after this list).
- **Optimizers**: Muon (matrices, lr=0.04), Adam (embeddings lr=0.05, scalars
lr=0.04).
- **No mid-training validation**, since it eats the wall-clock budget; the final
  eval runs once post-training over the full validation split.
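
The schedule is simple enough to write down. The function below sketches the
shape described above (linear warmup, flat plateau, linear cooldown); it is not
code lifted from `train_gpt.py`:

```python
def lr_scale(step: int, total: int = 12000, warmup: int = 30,
             cooldown: int = 1500) -> float:
    """Multiplier on the base LRs (Muon 0.04, Adam embeddings 0.05)."""
    if step < warmup:                # 30-step linear warmup
        return (step + 1) / warmup
    if step >= total - cooldown:     # 1500-step linear cooldown to 0
        return max(0.0, (total - step) / cooldown)
    return 1.0                       # flat in between
```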

## Expected outcome

- **Target: `val_bpb` ≈ 1.08–1.15**, a meaningful jump from the baseline's 1.22.
- Reaching 1.02 likely requires 2-3 more iterations on top (span corruption,
  4x recurrence, maybe distillation). This is a solid foundation.

## Reproduce

```bash
# one-time
pip install brotli sentencepiece -q
python3 data/cached_challenge_fineweb.py --variant sp1024

# full 8xH100 run
SEED=42 bash records/track_10min_16mb/2026-04-17_FastAttn_MTP_DepthRec/run_leaderboard_8xh100.sh 2>&1 | tee train_seed42.log

# smoke (1 GPU, 2 min)
bash records/track_10min_16mb/2026-04-17_FastAttn_MTP_DepthRec/run_smoke_1gpu.sh
```

## Files

- `train_gpt.py` — forked from root baseline + depth-recurrence + MTP
- `run_leaderboard_8xh100.sh` — production launcher
- `run_smoke_1gpu.sh` — sanity check
- `submission.json` — leaderboard metadata (val_bpb filled after run)

--- run_leaderboard_8xh100.sh ---

#!/usr/bin/env bash
# FastAttn + MTP + Depth-Recurrence leaderboard run (8xH100, 10-min cap).
#
# Strategy: fork of the proven 43ms/step baseline, with three surgical upgrades:
# (1) depth recurrence (weights shared across NUM_REPS passes)
# (2) multi-token prediction (auxiliary loss at t+2)
# (3) slightly bigger width (576 vs 512) since DR is param-free
set -euo pipefail

: "${DATA_PATH:=./data/datasets/fineweb10B_sp1024}"
: "${TOKENIZER_PATH:=./data/tokenizers/fineweb_1024_bpe.model}"
: "${SEED:=42}"
: "${RUN_ID:=fastattn_mtp_dr_$(date +%s)}"

export NCCL_IB_DISABLE=1
export TORCH_CUDNN_V8_API_ENABLED=1
export CUDA_DEVICE_MAX_CONNECTIONS=1
export OMP_NUM_THREADS=1

HERE="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
TRAIN="${HERE}/train_gpt.py"

DATA_PATH="${DATA_PATH}" \
TOKENIZER_PATH="${TOKENIZER_PATH}" \
SEED="${SEED}" \
RUN_ID="${RUN_ID}" \
MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-590}" \
ITERATIONS="${ITERATIONS:-12000}" \
WARMUP_STEPS="${WARMUP_STEPS:-30}" \
WARMDOWN_ITERS="${WARMDOWN_ITERS:-1500}" \
VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-0}" \
TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-100}" \
TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-1024}" \
TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}" \
VOCAB_SIZE="${VOCAB_SIZE:-1024}" \
NUM_LAYERS="${NUM_LAYERS:-7}" \
MODEL_DIM="${MODEL_DIM:-576}" \
NUM_HEADS="${NUM_HEADS:-8}" \
NUM_KV_HEADS="${NUM_KV_HEADS:-2}" \
MLP_MULT="${MLP_MULT:-2}" \
NUM_REPS="${NUM_REPS:-2}" \
MTP_WEIGHT="${MTP_WEIGHT:-0.3}" \
TIE_EMBEDDINGS="${TIE_EMBEDDINGS:-1}" \
TIED_EMBED_LR="${TIED_EMBED_LR:-0.05}" \
MATRIX_LR="${MATRIX_LR:-0.04}" \
SCALAR_LR="${SCALAR_LR:-0.04}" \
QK_GAIN_INIT="${QK_GAIN_INIT:-1.5}" \
torchrun --standalone --nproc_per_node=8 "${TRAIN}"
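
Every knob above is handed to `train_gpt.py` as an environment variable, which
suggests the trainer reads its config roughly as below. This is an assumed
pattern, and the helper names are hypothetical:

```python
import os

def env_int(name: str, default: int) -> int:
    # Hypothetical helper: read an integer knob from the environment.
    return int(os.environ.get(name, str(default)))

def env_float(name: str, default: float) -> float:
    # Hypothetical helper: read a float knob from the environment.
    return float(os.environ.get(name, str(default)))

NUM_LAYERS = env_int("NUM_LAYERS", 7)
MODEL_DIM  = env_int("MODEL_DIM", 576)
NUM_REPS   = env_int("NUM_REPS", 2)
MTP_WEIGHT = env_float("MTP_WEIGHT", 0.3)
MATRIX_LR  = env_float("MATRIX_LR", 0.04)
```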

--- run_smoke_1gpu.sh ---

#!/usr/bin/env bash
# Single-GPU smoke test to verify the new arch compiles and trains a few steps.
set -euo pipefail

: "${DATA_PATH:=./data/datasets/fineweb10B_sp1024}"
: "${TOKENIZER_PATH:=./data/tokenizers/fineweb_1024_bpe.model}"

export NCCL_IB_DISABLE=1
export TORCH_CUDNN_V8_API_ENABLED=1
export CUDA_DEVICE_MAX_CONNECTIONS=1
export OMP_NUM_THREADS=1

HERE="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
TRAIN="${HERE}/train_gpt.py"

DATA_PATH="${DATA_PATH}" \
TOKENIZER_PATH="${TOKENIZER_PATH}" \
RUN_ID="smoke_$(date +%s)" \
MAX_WALLCLOCK_SECONDS=120 \
ITERATIONS=200 \
WARMUP_STEPS=10 \
WARMDOWN_ITERS=50 \
VAL_LOSS_EVERY=0 \
TRAIN_LOG_EVERY=20 \
TRAIN_SEQ_LEN=1024 \
TRAIN_BATCH_TOKENS=65536 \
VOCAB_SIZE=1024 \
NUM_LAYERS=4 \
MODEL_DIM=256 \
NUM_HEADS=4 \
NUM_KV_HEADS=1 \
MLP_MULT=2 \
NUM_REPS=2 \
MTP_WEIGHT=0.3 \
TIE_EMBEDDINGS=1 \
torchrun --standalone --nproc_per_node=1 "${TRAIN}"

--- submission.json ---

{
"track": "track_10min_16mb",
"name": "FastAttn_MTP_DepthRec",
"date": "2026-04-17",
"val_loss": null,
"val_bpb": null,
"bytes_total": null,
"bytes_code": null,
"bytes_model": null,
"training_time_seconds": 600,
"gpus": "8xH100",
"notes": "Fork of proven baseline. Adds (1) depth recurrence NUM_REPS=2, (2) multi-token prediction MTP_WEIGHT=0.3, (3) width 576 (vs 512). 7 physical layers, 14 effective via recurrence."
}