Skip to content

Commit ea6e666

Browse files
committed
Add setup script for 12L QAT Int4-MLP submission
1 parent 20a38ef commit ea6e666

1 file changed

Lines changed: 92 additions & 0 deletions

File tree

  • records/track_10min_16mb/2026-03-25_QAT_Int4MLP_12L
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
#!/bin/bash
# ---------------------------------------------------------------
# Setup script for 12L QAT Int4-MLP submission
# Run from the repo root: bash records/track_10min_16mb/2026-03-25_QAT_Int4MLP_12L/setup.sh
# ---------------------------------------------------------------
# Strict mode: -e exit on error, -u error on unset variables,
# -o pipefail so a failing stage anywhere in a pipeline fails the pipeline
# (plain `set -e` would let e.g. a failed download piped into another
# command go unnoticed).
set -euo pipefail

# Resolve the repository root (three levels above this script) and work from there,
# so the script behaves the same no matter which directory it is invoked from.
script_dir="$(dirname "$0")"
REPO_ROOT="$(cd "${script_dir}/../../.." && pwd)"
cd "$REPO_ROOT"

banner="============================================"
echo "$banner"
echo " 12L QAT Int4-MLP — Environment Setup"
echo "$banner"

# ---------------------------------------------------------------
# 1. Python dependencies
# ---------------------------------------------------------------
echo ""
echo "[1/3] Installing Python dependencies..."
pip install --upgrade pip -q
pip install numpy tqdm sentencepiece huggingface-hub -q

# Probe the installed torch version; "none" when torch is absent entirely.
TORCH_VER=$(python3 -c "import torch; print(torch.__version__)" 2>/dev/null || echo "none")
# Only a 2.9.x build compiled against CUDA 12.8 is acceptable for this run.
if [[ "$TORCH_VER" == *"2.9"*"+cu128"* ]]; then
  echo " torch $TORCH_VER already OK."
else
  echo " Upgrading torch to 2.9.1+cu128 (current: $TORCH_VER)..."
  pip install torch --index-url https://download.pytorch.org/whl/cu128 --no-cache-dir --force-reinstall -q
fi
echo " Done."

# ---------------------------------------------------------------
# 2. Flash Attention 3 (Hopper)
# ---------------------------------------------------------------
echo ""
echo "[2/3] Installing Flash Attention 3..."

# The import itself exits non-zero when the module is missing or broken,
# so its exit status is a sufficient check — no need to print a sentinel
# string and grep for it as the original did.
if python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; then
  echo " Already installed and working — skipping."
else
  # Prebuilt cu128/torch-2.9.1 wheels; --no-deps avoids pip touching torch.
  pip install flash_attn_3 --no-deps --force-reinstall --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291/
  echo " Installed."
fi


# ---------------------------------------------------------------
# 3. Dataset + Tokenizer (sp1024)
# ---------------------------------------------------------------
# Fetch the FineWeb shards pre-tokenized with the sp1024 SentencePiece model.
printf '\n'
echo "[3/3] Downloading dataset (sp1024)..."
python3 data/cached_challenge_fineweb.py --variant sp1024
echo " Done."

# ---------------------------------------------------------------
# Verification
# ---------------------------------------------------------------
# Print a summary of the environment (interpreter, torch/CUDA, GPUs,
# FlashAttention-3 availability, dataset shard counts) so a failed setup
# is obvious before a 10-minute training run is launched.
echo ""
echo "============================================"
echo " Verification"
echo "============================================"

# Quoted delimiter ('PYEOF') keeps the heredoc literal — no shell expansion.
python3 - << 'PYEOF'
import sys, torch, glob  # numpy was imported here before but never used — dropped

print(f"Python : {sys.version.split()[0]}")
print(f"PyTorch : {torch.__version__}")
print(f"CUDA : {torch.cuda.is_available()}")
print(f"GPUs : {torch.cuda.device_count()}")

for i in range(torch.cuda.device_count()):
    p = torch.cuda.get_device_properties(i)
    print(f" GPU {i} : {p.name} ({p.total_memory // 1024**3}GB)")

try:
    from flash_attn_interface import flash_attn_func
    print("FlashAttn3 : OK")
except ImportError:
    print("FlashAttn3 : MISSING!")

train = sorted(glob.glob("./data/datasets/fineweb10B_sp1024/fineweb_train_*.bin"))
val = sorted(glob.glob("./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin"))
print(f"Train shards : {len(train)}")
print(f"Val shards : {len(val)}")
PYEOF

# Closing banner: show the exact command to launch training.
# A quoted heredoc emits the same bytes as the original run of echo calls.
cat <<'EOF'

============================================
 Setup complete. Run training with:

 tmux
 torchrun --nproc_per_node=8 records/track_10min_16mb/2026-03-25_QAT_Int4MLP_12L/train_gpt.py

============================================
EOF

0 commit comments

Comments
 (0)