#!/usr/bin/env bash
set -euo pipefail

# A/B/C TEST: Uniform vs Progressive MLP vs Progressive MLP + Fat BigramHash
#
# All runs use identical #180 config except:
#   A: Uniform MLP 3.0× + BigramHash 10240/128 (control — exact #180)
#   B: Progressive MLP 1.5×→4.5× + BigramHash 10240/128 (same total params as A)
#   C: Progressive MLP 1.5×→4.5× + BigramHash 16384 buckets, dim 192
#      (compensate thin early layers with richer bigram input)
#
# 1xGPU, comparing BPB delta only. ~30 min total.
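#
# Param-budget arithmetic (a sketch, assuming train_gpt.py sizes each MLP
# hidden layer as round(mult * MODEL_DIM)): uniform A uses 10 layers at
# 3.0 * 512 = 1536 hidden units (15360 total); B's ramp spans 768 -> 2304,
# and its ten multipliers also sum to 30.0, so the total matches A's.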

export DATA_PATH="${DATA_PATH:-./data/datasets/fineweb10B_sp1024/}"
export TOKENIZER_PATH="${TOKENIZER_PATH:-./data/tokenizers/fineweb_1024_bpe.model}"

LOGDIR="logs/abc_progressive_mlp_$(date +%Y%m%d_%H%M%S)"
mkdir -p "$LOGDIR"

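# Shared config, mirroring run #180. MAX_WALLCLOCK_SECONDS=600 caps each run
# at ~10 min of training, which is where the ~30 min total for three runs
# comes from.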
COMMON="NUM_LAYERS=10 \
MODEL_DIM=512 \
NUM_HEADS=8 \
NUM_KV_HEADS=4 \
MLP_MULT=3.0 \
TIE_EMBEDDINGS=1 \
VOCAB_SIZE=1024 \
TRAIN_BATCH_TOKENS=786432 \
TRAIN_SEQ_LEN=2048 \
ITERATIONS=20000 \
WARMDOWN_ITERS=3000 \
WARMUP_STEPS=20 \
MAX_WALLCLOCK_SECONDS=600 \
TRAIN_LOG_EVERY=100 \
VAL_LOSS_EVERY=500 \
WEIGHT_DECAY=0.04 \
MATRIX_LR=0.02 \
SCALAR_LR=0.02 \
TIED_EMBED_LR=0.03 \
MUON_MOMENTUM=0.99 \
MUON_MOMENTUM_WARMUP_START=0.92 \
MUON_MOMENTUM_WARMUP_STEPS=1500 \
GRAD_CLIP_NORM=0.3 \
EVAL_STRIDE=64 \
SWA_ENABLED=1 \
SWA_START_FRAC=0.4 \
SWA_EVERY=50 \
SEED=42 \
NCCL_IB_DISABLE=1"
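
# Note: inside double quotes, backslash-newline collapses COMMON to one line,
# and the deliberately unquoted $COMMON below word-splits into VAR=VAL pairs
# for env. Safe here only because no value contains whitespace.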

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"

PROGRESSIVE="1.5,1.83,2.17,2.5,2.83,3.17,3.5,3.83,4.17,4.5"
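
# Optional sanity check (not part of the experiment): the schedule is a linear
# ramp in steps of ~1/3, and its mean must come out to 3.0x for B/C to keep
# the same MLP parameter budget as the uniform control.
echo "$PROGRESSIVE" | tr ',' '\n' \
  | awk '{ s += $1; n += 1 } END { printf "MLP schedule mean: %.2fx over %d layers\n", s/n, n }'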

echo "============================================"
echo " A/B/C Test: MLP Distribution + BigramHash"
echo " 1xGPU — comparing delta only"
echo " Logs: $LOGDIR"
echo "============================================"

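# With set -euo pipefail, a failed training process (or tee) aborts the whole
# script, so each later run only starts if the earlier ones exited cleanly.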
# --- RUN A: Uniform 3.0× (control) ---
echo ""
echo "=== [A] UNIFORM MLP 3.0x + BigramHash 10240/128 ==="
env $COMMON BIGRAM_VOCAB_SIZE=10240 BIGRAM_DIM=128 MLP_SCHEDULE="" RUN_ID="abc_A_uniform" \
  python "$SCRIPT_DIR/train_gpt.py" \
  2>&1 | tee "$LOGDIR/A_uniform.log"

# --- RUN B: Progressive 1.5→4.5× ---
echo ""
echo "=== [B] PROGRESSIVE MLP 1.5x→4.5x + BigramHash 10240/128 ==="
env $COMMON BIGRAM_VOCAB_SIZE=10240 BIGRAM_DIM=128 MLP_SCHEDULE="$PROGRESSIVE" RUN_ID="abc_B_progressive" \
  python "$SCRIPT_DIR/train_gpt.py" \
  2>&1 | tee "$LOGDIR/B_progressive.log"

# --- RUN C: Progressive 1.5→4.5× + Fat BigramHash ---
echo ""
echo "=== [C] PROGRESSIVE MLP 1.5x→4.5x + BigramHash 16384/192 ==="
env $COMMON BIGRAM_VOCAB_SIZE=16384 BIGRAM_DIM=192 MLP_SCHEDULE="$PROGRESSIVE" RUN_ID="abc_C_prog_fatbigram" \
  python "$SCRIPT_DIR/train_gpt.py" \
  2>&1 | tee "$LOGDIR/C_prog_fatbigram.log"

# --- RESULTS ---
echo ""
echo "============================================"
echo " A/B/C RESULTS"
echo "============================================"
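# Scrape the last occurrence of each metric from the logs; grep -P's \K drops
# everything matched before it, so only the value itself is captured.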
for label in A_uniform B_progressive C_prog_fatbigram; do
  f="$LOGDIR/${label}.log"
  # "|| true" keeps set -e / pipefail from killing the loop when a pattern
  # is absent (e.g. a run died before reporting); the echo then prints N/A.
  bpb=$(grep -oP "final_int8_zlib_roundtrip val_loss:\S+ val_bpb:\K\S+" "$f" 2>/dev/null | tail -1 || true)
  raw_bpb=$(grep -oP "^step:\d+/\d+ val_loss:\S+ val_bpb:\K\S+" "$f" 2>/dev/null | tail -1 || true)
  steps=$(grep -oP 'stopping_early.*step:\K\d+' "$f" 2>/dev/null | tail -1 || true)
  size=$(grep -oP 'Total submission size int8\+zlib: \K\d+' "$f" 2>/dev/null | tail -1 || true)
  params=$(grep -oP 'model_params:\K\d+' "$f" 2>/dev/null | tail -1 || true)
  echo " ${label}: steps=${steps:-N/A} quant_bpb=${bpb:-N/A} raw_bpb=${raw_bpb:-N/A} params=${params:-N/A} bytes=${size:-N/A}"
done
echo ""
echo " A = control (exact #180)"
echo " B < A → progressive MLP helps"
echo " C < B → fat BigramHash compensates thin early layers"
echo " C < A but B > A → need both together"
echo "============================================"