133 changes: 133 additions & 0 deletions POD_SETUP.md
# RunPod Setup & Training Sequence

Template: `runpod/parameter-golf:latest` — 8×H100 SXM, $21.52/hr on-demand

---

## 1. First-time pod setup

### Install flash-attn (FA2)
The template does not include flash-attn. Install it once per pod:
```bash
pip install flash-attn --no-cache-dir --no-build-isolation
```
Takes ~10 minutes to compile. `--no-build-isolation` is required so the build can find the
already-installed torch. `--no-cache-dir` avoids a cross-device link error on this filesystem.

Verify:
```bash
python3 -c "from flash_attn import flash_attn_func; print('FA2 available')"
```

### Install zstandard
```bash
pip install zstandard
```
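zstandard backs the zstd-22 artifact compression used for the exported model (see the Int6 record). A quick roundtrip sanity check — a minimal sketch, not the repo's export code:
```python
# Sanity check only: compress and restore a stand-in payload at zstd level 22.
import zstandard as zstd

payload = b"\x00" * (1 << 20)  # stand-in for serialized weights
compressed = zstd.ZstdCompressor(level=22).compress(payload)
restored = zstd.ZstdDecompressor().decompress(compressed)
assert restored == payload
print(f"raw {len(payload)} -> zstd-22 {len(compressed)} bytes")
```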

---

## 2. Clone the repo

The `/workspace/parameter-golf` directory from the template is not a git repo. Remove it and clone:
```bash
cd /workspace && rm -rf parameter-golf
git clone https://github.com/mrdavtan/parameter-golf.git
cd parameter-golf
git checkout 11l-xsa-ema-ttt # current active branch
```

---

## 3. Download dataset and tokenizer

Only needed once per pod (persists in `/workspace/parameter-golf/data/`):
```bash
python3 data/cached_challenge_fineweb.py --variant sp1024
```

This downloads ~8B tokens of FineWeb shards + the sp1024 tokenizer from HuggingFace. Takes a few minutes.

Verify:
```bash
ls data/datasets/fineweb10B_sp1024/
ls data/tokenizers/
```

---

## 4. Testing workflow — always follow this sequence

**Never stack multiple new features in a single full run. Test one at a time.**

### Step 1 — establish the Tier 2 baseline (~3 min, ~$1)

Run the current best config with all schedule-dependent features off:
```bash
git pull
unset MLP_HIDDEN QUANT_BITS RUN_ID SEED
TIER2_MODE=1 torchrun --standalone --nproc_per_node=8 \
records/track_10min_16mb/2026-03-21_11L_XSA_EMA_TTT/train_gpt.py
```
Record **val_bpb at step 2000**. This is your baseline for the current session.
Expected: `*** TIER2_MODE ***` banner in startup log.

### Step 2 — Tier 2 test of a new feature (~3 min, ~$1)

Change exactly ONE feature relative to the baseline:
```bash
TIER2_MODE=1 XSA_LAST_N=0 torchrun ... # test: disable XSA
TIER2_MODE=1 SMEAR_GATE=0 torchrun ... # test: disable SmearGate
TIER2_MODE=1 NUM_LAYERS=9 torchrun ... # test: fewer layers
```
If val_bpb@step2000 is worse than baseline → skip the feature.
If better → proceed to Tier 3.

### Step 3 — full 10-minute run (Tier 3, ~$3.60)

Only run after Tier 2 shows improvement:
```bash
unset MLP_HIDDEN QUANT_BITS RUN_ID SEED
torchrun --standalone --nproc_per_node=8 \
records/track_10min_16mb/2026-03-21_11L_XSA_EMA_TTT/train_gpt.py
```
Startup log should confirm:
- `*** TIER2_MODE ***` NOT present
- `ntk_rope:enabled train_seq_len:1024 eval_seq_len:2048`
- `xsa_last_n:4 active_layers:[7, 8, 9, 10]`
- `ema:initialized decay=0.997`
- `qat:True (activates when 480s elapsed; guarantees 120s of QAT)`
- step_avg ~50ms (NTK-RoPE at seq_len=1024), targeting ~10000+ steps

### What TIER2_MODE cannot test

These features only matter at the end of training — skip them in Tier 2:
| Feature | Why Tier 2 can't test it | How to test |
|---------|--------------------------|-------------|
| EMA | Benefits from averaging converged weights, not early-stage weights | Full run only |
| TTT | Needs a well-trained model to adapt | Full run only |
| SWA | Requires many checkpoint samples across long warmdown | Full run only |
| QAT | Disabled in TIER2_MODE automatically | Full run only |

---

## 5. IMPORTANT — always unset env vars before any run

```bash
unset MLP_HIDDEN QUANT_BITS RUN_ID SEED
echo "MLP_HIDDEN=${MLP_HIDDEN} QUANT_BITS=${QUANT_BITS}" # should be empty
```

An MLP_HIDDEN value inherited from a prior run will silently bloat the model from ~26.8M to ~31.2M params
and push the artifact over 16MB. This has happened before.

---

## Notes
- PyTorch version on template: `2.9.1+cu128`
- `flash_attn_interface` (FA3) is NOT available — use FA2 (`flash_attn`)
- FA3 install fails due to a cross-device link error on this pod filesystem
- Data path: `./data/datasets/fineweb10B_sp1024/`
- Tokenizer path: `./data/tokenizers/fineweb_1024_bpe.model`
- Logs saved to `./logs/<RUN_ID>.txt`
- At NTK-RoPE seq_len=1024: expect ~50ms/step, ~10000 steps in 600s
- At seq_len=2048 (no NTK): expect ~85-92ms/step, ~6500-7000 steps in 600s
58 changes: 58 additions & 0 deletions records/track_10min_16mb/2026-03-20_Int6_3xMLP/INT5_ABLATION.md
# Ablation: Int5 Quantization (QUANT_BITS=5, MLP_HIDDEN=1920)

**Date:** 2026-03-20
**Script:** `records/track_10min_16mb/2026-03-20_Int6_3xMLP/train_gpt.py`
**Config:** `QUANT_BITS=5 MLP_HIDDEN=1920 SEED=1337`
**Verdict:** ❌ Negative — catastrophic quantization degradation

---

## Results

| Metric | Value |
|--------|-------|
| Pre-quant val_bpb (step 10200, wallclock stop) | **1.1885** |
| Post-quant val_bpb (int5 roundtrip) | **1.5458** |
| Quantization gap | **+0.357 bpb** |
| Step avg | 58.9ms |
| Steps completed | 10,200 (600s cap) |

### Comparison vs Int6

| | Pre-quant val_bpb | Post-quant val_bpb | Quant gap |
|---|---|---|---|
| **Int6** (our submission) | 1.1949 | **1.1708** | +0.024 |
| **Int5** (this ablation) | 1.1885 | 1.5458 | +0.357 |

Int5's quantization gap is **15× larger** than int6's.

---

## Why Int5 Fails

Int5 per-row quantization maps weights to the range `[-15, 15]` — only **31 levels**.
Int6 uses `[-31, 31]` — **63 levels**.
Int8 uses `[-127, 127]` — **255 levels**.

Each bit removed halves the number of representable values: int8 → int6 already cut the levels by 4× (255 → 63), and int6 was at the edge of viability. Dropping one more bit to 31 levels makes the per-row quantization error large enough to destroy the model's language-modeling ability entirely.
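
A minimal sketch of symmetric per-row quantization (a hypothetical helper, not the repo's exact export code) makes the level counts concrete:
```python
import torch

def per_row_quant_roundtrip(w: torch.Tensor, bits: int) -> torch.Tensor:
    # Symmetric per-row quantization: each row gets its own scale so its largest
    # magnitude maps to the top representable level (int8 -> 127, int6 -> 31, int5 -> 15).
    qmax = 2 ** (bits - 1) - 1
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / qmax
    q = torch.round(w / scale).clamp(-qmax, qmax)
    return q * scale  # dequantized weights, i.e. what the evaluated model actually sees

w = torch.randn(512, 1536)
for bits in (8, 6, 5):
    err = (w - per_row_quant_roundtrip(w, bits)).abs().mean()
    print(f"int{bits}: mean abs roundtrip error {err:.5f}")
```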

The pre-quant score (1.1885) is actually slightly **better** than our int6 run's pre-quant (1.1949) — int5 frees enough artifact budget to fit MLP_HIDDEN=1920 (vs 1536 for int6), which gives more parameters and better raw training. But the quantization destroys all of that gain and then some.

## Side note: Triton QAT warnings

The run produced warnings:
```
tl.where with a non-boolean condition is deprecated... Got int8
```

These come from the QAT (quantization-aware training) triton kernel, which still uses int8 internally regardless of `QUANT_BITS`. QAT and the export quantization are independent: QAT simulates int8 during training, while `QUANT_BITS=5` only affects the post-training export serialization. This means int5 QAT (simulating 5-bit during training) was never actually tested — that would require modifying the QAT kernel itself.

---

## Conclusion

**Int5 is not viable with current post-training quantization.** The quantization gap (+0.357) far outweighs any capacity gain from the extra artifact headroom.

A potential future path would be **Int5 QAT** — actually simulating 5-bit quantization during forward passes so the model learns to be robust to int5 rounding. This would require modifying the QAT triton kernel to use 5-bit fake quantization instead of 8-bit. Given that even int8 QAT was a negative finding (PR #145), this is unlikely to be worth pursuing.
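
For completeness, the kernel change would amount to fake-quantizing at 5 bits in the forward pass with a straight-through gradient. A minimal PyTorch sketch of the idea (hypothetical, not the triton kernel):
```python
import torch

class FakeQuant5(torch.autograd.Function):
    """Straight-through 5-bit fake quantization: round in the forward pass,
    pass gradients through unchanged in the backward pass."""
    @staticmethod
    def forward(ctx, w):
        qmax = 15  # 5-bit symmetric range [-15, 15]
        scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / qmax
        return torch.round(w / scale).clamp(-qmax, qmax) * scale

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out  # straight-through estimator

# usage inside a linear layer's forward: y = x @ FakeQuant5.apply(self.weight).t()
```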

**Next experiment:** 11L + FA2 + WD=0.04 (`11l-fa3-wd04` branch)
118 changes: 118 additions & 0 deletions records/track_10min_16mb/2026-03-20_Int6_3xMLP/README.md
# 2026-03-20_Int6_3xMLP

**Mean val_bpb = 1.1724** (5 seeds, std=0.0026, p<0.01 vs baseline)

Int6 per-row quantization with 3x MLP expansion, accompanied by 9 controlled ablations.

### Statistical Validation

| Seed | val_bpb |
|------|---------|
| 31337 | 1.1703 |
| 1337 | 1.1708 |
| 2024 | 1.1712 |
| 42 | 1.1732 |
| 7 | 1.1767 |
| **Mean** | **1.1724** |
| **Std** | **0.0026** |

---

## Approach

Int6 per-row quantization stores weights in 6 bits ([-31, 31]) instead of int8's 8 bits ([-127, 127]). Combined with zstd-22 compression, this saves ~3MB of artifact budget compared to int8+zlib, enabling a 3x MLP expansion (hidden=1536 vs baseline's 1024). The wider MLP adds 4.8M parameters.
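
A quick back-of-envelope for the added parameters (assuming a standard two-matrix MLP, which matches the reported numbers):
```python
dim, layers = 512, 9
delta_hidden = 1536 - 1024
extra = 2 * dim * delta_hidden * layers  # up-proj + down-proj per layer
print(extra)  # 4,718,592 -- exactly the 21,778,504 - 17,059,912 param delta below
```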

### Configuration

```
VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4
MLP_HIDDEN=1536 TIE_EMBEDDINGS=1 FP16_EMBED_EXPORT=1
QUANT_BITS=6 USE_ZSTD=1
WARMDOWN_ITERS=20000 MATRIX_LR=0.06 SCALAR_LR=0.06 TIED_EMBED_LR=0.07
GRAD_CLIP_NORM=1.0 MUON_BACKEND_STEPS=5 MUON_MOMENTUM=0.99
EVAL_STRIDE=64 MAX_WALLCLOCK_SECONDS=600
```

### Key Metrics

| Metric | Naive Baseline | Int6 + 3xMLP |
|--------|---------------|-----------------|
| model_params | 17,059,912 | 21,778,504 |
| Pre-quant val_bpb | 1.2172 | 1.1949 |
| **Post-quant val_bpb** | **1.2244** | **1.1708** |
| Artifact size | 15,863,489 (int8+zlib) | 15,175,136 (int6+zstd) |
| Artifact headroom | 137KB | 825KB |
| Steps | ~13,800 | 12,507 |
| step_avg | 43.5ms | 48.0ms |

---

## Ablations

| # | Technique | val_bpb | vs Control (1.1929) | Verdict |
|---|-----------|---------|--------------------|---------|
| 1 | SWA (weight averaging) | 1.1933 | +0.0004 | **No effect** at WD=1200 |
| 2 | Doc-isolated eval | 1.2015 | +0.0086 | **Hurts** at stride=64 |
| 3 | Curriculum learning | 1.1942 | +0.0013 | **No effect** |
| 4 | Multi-token prediction | 1.1947 | +0.0018 | **No effect** |
| 5 | **Int6 + 3x MLP** | **1.1708** | **-0.0221** | **Best result** |
| 6 | + SmearGate + BigramHash | 1.1739 | -0.0190 | **Hurts** on top of int6 |
| 7 | Depth recurrence + Huginn (skips) | 4.34 | — | **Catastrophic** |
| 8 | Depth recurrence + Huginn (flat) | 5.58 | — | **Catastrophic** |
| 9 | Int8 QAT (PR #145) | 1.2052 | +0.0123 | **Overhead exceeds recovery** |

### Key Negative Findings

**1. Doc-isolated eval hurts at stride=64**

The LoRA TTT entry (#77) found doc-isolation worth +0.011 bpb at stride=256; at stride=64 it instead costs ~0.009 bpb. With stride=64, most tokens already have 960+ tokens of context, so removing cross-document context mainly means start-of-document tokens lose all context, which hurts more than the cleaner context helps. Somewhere between stride 64 and 256 there is a crossover where doc-isolation flips from harmful to helpful.

**2. SmearGate + BigramHash hurt with int6**

SmearGate + BigramHash have been reported as helpful in other entries, but on the int6+3xMLP base they cost +0.003 BPB. BigramHash adds ~524K parameters that are int6-quantized and see too few training steps to pay off. The implementations may also differ from the originals, or the gains may depend on interaction with other techniques (OrthoInit, a specific SWA schedule).

**3. Huginn eval-time scaling fails at small scale**

Depth recurrence (3 shared blocks × 3 loops = 9 effective layers) with Huginn-style eval scaling (6 loops at eval) produces random output (4.34-5.58 BPB). Tested both with U-Net skips (skips disabled for extra loops) and flat loops (trained without skips). Neither works. The 3 shared blocks at 7.6M params lack sufficient capacity to learn general iterative refinement. Huginn was validated at 3.5B — the technique does not transfer to 7.6M scale.

**4. Int8 QAT overhead exceeds recovery**

Exact INT8_CLIP_Q percentile matching via `torch.quantile` adds ~20% per-step overhead, costing ~2,000 training steps. The lost training tokens hurt more than the ~0.007 BPB quantization gap recovery. QAT likely only pays off with int6 (larger gap to close) using a faster approximate quantile.
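
One cheap direction (an untested sketch, not something run here): estimate the clip percentile from a random subsample instead of the full tensor.
```python
import torch

def approx_abs_quantile(x: torch.Tensor, q: float, sample: int = 65536) -> torch.Tensor:
    # Estimate the q-th quantile of |x| from a random subsample; far cheaper than an
    # exact torch.quantile over millions of elements, at the cost of a little noise.
    flat = x.reshape(-1)
    if flat.numel() > sample:
        idx = torch.randint(0, flat.numel(), (sample,), device=flat.device)
        flat = flat[idx]
    return torch.quantile(flat.abs().float(), q)
```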

### Implementation Bugs Discovered

**SWA bf16 accumulation:** Initial SWA implementation accumulated weights in bf16, producing val_bpb=2.62 after thousands of additions. Fix: accumulate in float32, sample every 50 steps.
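
A minimal sketch of the fixed accumulation scheme (hypothetical helper, not the repo's code):
```python
import torch

class SWAAccumulator:
    """Accumulate sampled weights in float32 so thousands of additions keep precision."""
    def __init__(self, model):
        self.sums = {k: torch.zeros_like(v, dtype=torch.float32)
                     for k, v in model.state_dict().items()}
        self.count = 0

    def sample(self, model):  # call every 50 steps
        for k, v in model.state_dict().items():
            self.sums[k] += v.float()
        self.count += 1

    def averaged(self, dtype=torch.bfloat16):
        return {k: (s / self.count).to(dtype) for k, s in self.sums.items()}
```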

**torch.compile graph priming:** Pre-compiling both QAT and non-QAT graphs during warmup caused 50% step time regression for the non-QAT path. Fix: don't pre-prime conditional code paths.

**zstd/zlib decompression mismatch:** Compressing with zstd then decompressing with zlib crashes. Fix: match decompressor to compressor.

---

## Reproduction

```bash
cd /workspace
git clone https://github.com/mrdavtan/parameter-golf.git
cd parameter-golf && git checkout int6-3xMLP-pr
pip install sentencepiece huggingface_hub zstandard
python3 data/cached_challenge_fineweb.py --variant sp1024

torchrun --standalone --nproc_per_node=8 \
records/track_10min_16mb/2026-03-20_Int6_3xMLP/train_gpt.py
```

Environment variables listed in Configuration section above.

Hardware: 8×H100 SXM (RunPod Parameter Golf template), PyTorch 2.9.1+cu128

---

## Acknowledgments

- Int6 quantization approach studied from WarmdownQuantization entry by @samuellarson
- Sliding window evaluation from #50 by @mattqlf
- Hyperparameter tuning informed by #65 (@samuellarson) and #128 (@rsavitt)
- SmearGate/BigramHash implementations based on modded-nanogpt community
- Depth recurrence inspired by PR #167 and Huginn (arxiv 2502.05171)
- Doc-isolated eval concept from LoRA TTT entry (#77) by @samacquaviva
logs/int6_3xMLP_seed1337.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
smear_gate:False bigram_hash:False swa:False
fp16_embed_export:enabled (tok_emb.weight kept in fp16, ~1024KB)
qat:False (activates at 60% of iterations = step 6000)
model_params:21778504 (unique_layers:9 loops:1 effective_depth:9 lora_rank:0 lora_params:0)
world_size:8 grad_accum_steps:1
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.07 head_lr:0.0 matrix_lr:0.06 scalar_lr:0.06
train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
step:200/20000 train_loss:2.7747 train_time:11038ms step_avg:55.19ms
step:400/20000 train_loss:2.3053 train_time:20631ms step_avg:51.58ms
step:600/20000 train_loss:2.5173 train_time:30120ms step_avg:50.20ms
step:800/20000 train_loss:2.2576 train_time:39662ms step_avg:49.58ms
step:1000/20000 train_loss:2.3326 train_time:49256ms step_avg:49.26ms
step:2000/20000 train_loss:2.2082 train_time:97479ms step_avg:48.74ms
step:4000/20000 train_loss:2.1104 train_time:192942ms step_avg:48.24ms
step:6000/20000 train_loss:2.1675 train_time:288541ms step_avg:48.09ms
step:8000/20000 train_loss:2.0480 train_time:384151ms step_avg:48.02ms
step:10000/20000 train_loss:2.0419 train_time:479774ms step_avg:47.98ms
step:12000/20000 train_loss:1.9941 train_time:575666ms step_avg:47.97ms
step:12507/20000 val_loss:2.0176 val_bpb:1.1949 train_time:599980ms step_avg:47.97ms
stopping_early: wallclock_cap train_time:599980ms step:12507/20000
peak memory allocated: 11250 MiB reserved: 11380 MiB
Serialized model: 86099351 bytes
Code size: 71265 bytes
Total submission size: 86170616 bytes
quantization: 6-bit
Serialized model int8+zstd-22: 15103871 bytes (payload:22428960 raw_torch:22473755 payload_ratio:3.84x)
Total submission size int8+zstd-22: 15175136 bytes
final_eval_mode:sliding_window stride:64 batch_seqs:32 doc_isolated:False
final_int8_zlib_roundtrip val_loss:1.9768 val_bpb:1.1708 eval_time:80311ms
final_int8_zlib_roundtrip_exact val_loss:1.97676768 val_bpb:1.17075317
---
Run config: QUANT_BITS=6 USE_ZSTD=1 MLP_HIDDEN=1536 FP16_EMBED_EXPORT=1
WARMDOWN_ITERS=20000 MATRIX_LR=0.06 SCALAR_LR=0.06 TIED_EMBED_LR=0.07
GRAD_CLIP_NORM=1.0 MUON_BACKEND_STEPS=5 MUON_MOMENTUM=0.99
NUM_LAYERS=9 MODEL_DIM=512 EVAL_STRIDE=64 SEED=1337
Hardware: 8xH100 SXM (RunPod Parameter Golf template), PyTorch 2.9.1+cu128