diff --git a/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/README.md b/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/README.md
new file mode 100644
index 0000000000..9ed4b9e543
--- /dev/null
+++ b/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/README.md
@@ -0,0 +1,254 @@
+# Non-record: Eval-time lever ablations + architectural analysis on the SP8192 absolute-RoPE stack
+
+A companion to my [record PR #1716](https://github.com/openai/parameter-golf/pull/1716) (SP8192 + d=32 BigramHash + Path A v3, 3-seed mean **1.07882 bpb**). This PR documents the ablation path behind that record — what I tested, what worked, what didn't, and the architectural reason behind each null result. The **point of a non-record submission is signal quality**, not a leaderboard number; I am filing structured evidence so that the next person exploring this branch of the design space does not re-learn the same lessons at full training cost.
+
+> This is a **non-record** submission. Every number below was measured on 8× H100 SXM with the canonical SP8192 pipeline (see Reproducibility). Where a result reverses a prior hypothesis, I report the negative outcome in full; where a result was within seed noise, I declare it non-significant rather than chase it.
+
+## TL;DR
+
+| Lever | Hypothesis | Result | Verdict |
+|---|---|---|---|
+| `BIGRAM_DIM = 32` (vs 48 / 64) | smaller bigram regularizes | pre-quant −0.0002 bpb, marginal but consistent | ✅ kept in record stack |
+| **Path A v3 passthrough quantization** | int8 control-tensor + small-matrix quant fits int8 tok_emb under 16 MB | **total artifact 15.99 MB with 6.6 KB margin; 0 bpb cost at 5 d.p.** | ✅ **primary mechanism in record** |
+| `TTT_EPOCHS = 4` (vs 3) | more TTT compute → better adaptation | Δ −0.00008 bpb (within noise) | ❌ saturated |
+| `EVAL_SEQ_LEN = 4096, stride=128` (NTK-RoPE scaling kicks in automatically) | 2× attention context per scored token | pre-quant −0.00509 ✅, sliding +0.00555 ❌, TTT +0.00033 ❌ | ❌ architecture-limited (see §3.2) |
+| **SWA `window=1024` at training** (full 2×2 factorial at eval) | windowed attention unlocks eval-time KV chaining | all 3 SWA eval configs strictly worse than baseline; training-time SWA recovers ~0.002 of the eval-time cap but never beats baseline | ❌ SWA alone can't close the gap (see §3.5) |
+| **`TRAIN_SEQ_LEN = 4096`** (halved batch) | positions 2048-4095 in-distribution → eval context extension works | pre-quant −0.00196 ✅, sliding +0.00435 ❌, TTT +0.00428 ❌ | ❌ position-depth vs breadth tradeoff (see §3.6) |
+| **QAT v3** (matrices-only int6 fake-quant + GPTQ-matched scale) | train model to be quantization-robust | pre-quant +0.015, TTT diverged to 1.48 under score-first TTT | ❌ QAT × TTT incompatibility (see §3.3) |
+| **Adaptive Hadamard GPTQ** | random Hadamard rotation reduces quant MSE | Muon-trained weights have sub-Gaussian distributions; rotation does not help, sometimes hurts | ❌ null result, documented |
+
+**The structural finding:** every attempt to exploit the competition's eval-time working-memory asymmetry *within the current absolute-position RoPE architecture* fails for the same architectural reason — sliding eval scores tokens at a narrow tail-position band, and any technique that trades position depth (more samples per position) for position breadth (more positions covered) regresses on sliding.
The right next step is relative-position attention (ALiBi / NoPE / retrieval), where scored-token position becomes irrelevant. + +## 1. Design-space map + +The competition imposes **asymmetric compute constraints**: + +- **Artifact: hard cap 16,000,000 B.** Binds architecture choice. +- **Training: 10 min on 8× H100.** Binds step count and effective batch size. +- **Eval: 10 min on 8× H100 + 640 GB HBM.** **Essentially unbounded in memory**, only wallclock-bound in compute. + +That asymmetry defines two research tiers: + +**Tier A — Training-time levers** (widely explored by other submissions): architecture width/depth, loss shaping, optimizer, vocab, attention pattern. Every PR on the leaderboard since 2026-04-06 is a tier-A variation. + +**Tier B — Eval-time levers** (under-explored): TTT compute budget, eval sequence length, KV-cache strategies, quantization scheme, code packing. These use the ~$1 of "free" eval compute that the grader allocates per run. This PR systematically tests tier-B levers on top of a strong tier-A base (the PR #1394 / 2026-04-09 lineage). + +## 2. What I confirmed works (carried into the record PR) + +### 2.1 `BIGRAM_DIM = 32` + +Reducing `BigramHashEmbedding` projection dimension from common d=48 / d=64 to **d=32** produces a small but reproducible pre-quant improvement on this architecture. Observed across two d=48 retrains and three d=32 retrains: pre-quant post-EMA lands in the band 1.0856–1.0860 on d=32 versus 1.0850–1.0852 on d=48 — within seed noise but consistent in sign. More importantly for the record, the **smaller bigram shrinks `bigram.proj` from 24,576 to 16,384 parameters**, which is inside the `numel() ≤ 65536` small-matrix threshold that Path A v3 targets for aggressive int8. + +### 2.2 Path A v3 passthrough quantization (the primary mechanism) + +The baseline SP8192 + d=48 + int8 tok_emb recipe produces a **16.06 MB artifact — 65 KB over cap** (I reproduced this three times). Roughly 40 KB of that surplus sits in fp16 **passthrough tensors** that the canonical `gptq_mixed_quantize` leaves uncompressed because they fall under the `numel() ≤ 65536` threshold (small-matrix) or aren't 2-D (scalars / scales). + +My solution quantizes these aggressively: + +- **Control tensors** (`attn_scale`, `mlp_scale`, `resid_mix`, `skip_gates`, `skip_weights`): per-tensor int8 with a single fp32 scale. +- **Small 2-D matrices** (`bigram.proj`, `attn_gate_proj`, `smear_gate.weight`): per-row int8 with fp16 scales. +- **LZMA self-extracting code wrapper** (standard technique from prior records): shrinks the script from 53 KB raw to **18.1 KB** wrapped. + +Net effect: total submission drops from 16.06 MB → 15.99 MB, at a **measured cost of 0 bpb at 5-decimal precision** (the quant roundtrip bpb is unchanged vs baseline). + +The mechanism is elementary and the diff is small (~90 lines in `gptq_mixed_quantize` + `dequantize_mixed`), which is a feature rather than a bug — in this regime, **the arithmetic of byte accounting matters more than sophisticated compression schemes**. 
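
+
+For concreteness, a minimal sketch of the two quantization rules (the shipped version is the ~90-line diff in `patches/patch_pathav3_inline.py`; the function names below are illustrative, not the actual patch API):
+
+```python
+import torch
+
+def quant_control_per_tensor(t: torch.Tensor):
+    """Control tensors: int8 payload + a single fp32 scale."""
+    scale = (t.abs().max().clamp_min(1e-10) / 127.0).float()
+    q = torch.clamp(torch.round(t / scale), -127, 127).to(torch.int8)
+    return q, scale  # dequant: q.float() * scale
+
+def quant_small_matrix_per_row(t: torch.Tensor):
+    """Small 2-D matrices: int8 payload + one fp16 scale per row."""
+    row_max = t.abs().amax(dim=1, keepdim=True).clamp_min(1e-10)
+    scale = (row_max / 127.0).squeeze(-1).to(torch.float16)
+    q = torch.clamp(torch.round(t / scale.float().view(-1, 1)), -127, 127).to(torch.int8)
+    return q, scale  # dequant: q.float() * scale.float().view(-1, 1)
+```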

+
+Comparison with alternative fits-under-16 MB approaches I considered but rejected:
+
+| Alternative | Effect | Why I didn't pick it |
+|---|---|---|
+| int7 tok_emb | saves ~500 KB, costs ~0.005 bpb | Baseline bpb preservation was more valuable than 500 KB of extra headroom |
+| Per-group tok_emb scales (16 groups) | saves ~15 KB, costs ~0.001 bpb | **Not enough alone** (baseline is 65 KB over); ran the math |
+| LZMA on weights (instead of brotli) | artifact grew to 17.04 MB | **Anti-fits**: int-quantized weights have high local entropy; brotli wins |
+| Bit-packing (int6 packed into 6 bits) | saves 25% raw | **Anti-fits**: packed bits have higher entropy → **+2.6 MB after brotli** |
+
+## 3. What I tested and killed (tier-B and tier-A levers that failed)
+
+### 3.1 `TTT_EPOCHS = 4` — saturated
+
+**Hypothesis:** more SGD epochs per TTT chunk → more adaptation → lower bpb, paid for with ~112 s of extra eval time that fits inside the 600 s eval budget.
+
+**Result on seed 42 (eval-only on saved final_model.pt):**
+
+| Config | TTT val_bpb | TTT eval time |
+|---|---|---|
+| TTT_EPOCHS=3 (record baseline) | 1.07886574 | 336 s |
+| TTT_EPOCHS=4 (probe) | 1.07877614 | 390 s |
+| **Δ** | **−0.00008959** | +54 s |
+
+**Interpretation:** the improvement is about 0.6× my 3-seed sample standard deviation (σ = 0.000143). Effect size is in-noise. **Not a lever — saturated.**
+
+### 3.2 `EVAL_SEQ_LEN = 4096, stride = 128` — architecturally defeated
+
+**Hypothesis:** doubling sequence length at eval (NTK-aware RoPE scaling, already present on line 90 of the shipped Rotary class) doubles the attention context each scored token sees. At `stride = 128` the total number of sliding windows halves, so eval compute stays near-constant.
+
+**Result on seed 42:**
+
+| Metric | Baseline (seq=2048, stride=64) | Probe (seq=4096, stride=128) | Δ bpb |
+|---|---|---|---|
+| Pre-quant post-EMA val_bpb | 1.08584 | 1.08075 | **−0.00509** ✅ |
+| Quantized roundtrip val_bpb | 1.09678 | 1.09344 | **−0.00334** ✅ |
+| **Sliding val_bpb** | **1.08014** | **1.08569** | **+0.00555 ❌** |
+| TTT val_bpb | 1.07886 | 1.07919 | +0.00033 ❌ |
+
+**The pre-quant improvement is real** (averaged over all positions 0…4095 in each batch, the model benefits from richer context). **But the sliding-window scoring geometry reverses it** — and sliding is the number that counts.
+
+**Mechanism:** in the shipped sliding eval, each scored token sits at window-positions `[seq_len − stride, seq_len − 1]`. At `seq_len=2048, stride=64` those are positions 1984–2047, which the Rotary was trained on. At `seq_len=4096, stride=128` those are positions 3968–4095 — NTK-extrapolated, out-of-distribution. The average context gain is over 2k–4k prior tokens (positive), but the average query-phase degradation is applied to 100% of scored tokens (negative, and larger).
+
+Proof-level evidence that the sign flip is a query-position effect, not a context-gain effect: in `eval_val` (non-sliding), where scored tokens span positions 0–4095 uniformly, NTK extension improves bpb by −0.00509. In `eval_val_sliding`, where scored tokens are always at the tail, NTK extension regresses by +0.00555. The only variable that changes between these two is **which rotary phase the query sees**.
+
+### 3.3 QAT v3 (matrices-only int6 fake-quant) — TTT incompatibility
+
+Three days of QAT experiments (v1 with wrong scale formula, v2 with tok_emb fake-quant, v3 with GPTQ-matched scale + matrices-only).
 Summary:
+
+| QAT variant | Pre-quant bpb (vs non-QAT 1.08584) | Quantized bpb (vs non-QAT 1.09678) | TTT bpb (vs non-QAT 1.07886) |
+|---|---|---|---|
+| v1 (wrong scale) | +0.0001 (noise) | +0.0001 | — (didn't reach) |
+| v2 (correct scale, includes tok_emb fake-quant) | **+0.023** | −0.0074 | — pre-quant cost dominates |
+| v3 (correct scale, matrices only, warmup=0) | +0.015 | +0.014 | **diverged to 1.48** |
+
+**Key finding — QAT × score-first-TTT catastrophic interaction:**
+
+QAT v3 produced a respectable quantized artifact (albeit with +0.015 pre-quant drift), but **TTT evaluation diverged to val_bpb = 1.48169**. The mechanism:
+
+- During QAT training, weights are pushed onto the int6 quantization lattice; the fake-quant STE gradient nudges them toward lattice points.
+- During TTT, the model is fine-tuned via SGD on val chunks. SGD pushes weights off-lattice.
+- Because QAT-trained weights are highly sensitive to leaving the lattice, each SGD step damages the effective predictive distribution.
+- Cumulatively, over 3 TTT epochs × 1238 chunks, the model drifts into a region the quantizer can no longer approximate, and bpb explodes.
+
+**Practical implication:** QAT and score-first TTT are fundamentally at odds in this regime. Any future QAT attempt would need either (a) to freeze all matrix weights during TTT, or (b) to re-apply fake-quant *during* TTT SGD so the lattice is maintained. Both are non-trivial. **I do not recommend this avenue under the 10-min training budget.**
+
+### 3.4 Adaptive Hadamard GPTQ — null on Muon-trained weights
+
+Hadamard rotation prior to quantization is a well-known technique (e.g., QuaRot, SpinQuant) for distributing weight outliers. Theoretical motivation: random Hadamard rotation makes weights approximately Gaussian (CLT), so GPTQ's per-row scale captures the Gaussian range more efficiently.
+
+I implemented this as `adaptive_hadamard_gptq.py` (unit-tested) and measured MSE pre/post rotation on actual Muon-trained matrix weights across int5/int6/int7/int8 configurations. **Result: no significant MSE reduction on Muon weights — sometimes a slight regression.**
+
+**Mechanism:** Muon (momentum orthogonalized via Newton–Schulz iteration) produces weights with substantially *sub-Gaussian* per-row distributions — empirical excess kurtosis ≈ −1.2, nearly uniform with short tails. Random Hadamard rotation of a sub-Gaussian vector produces another sub-Gaussian vector with similar kurtosis. The rotation does not smooth a distribution that is already as smooth as it can get, so there is no quantization benefit.
+
+SpinQuant's success on standard LLMs (AdamW-trained, fat-tailed Gaussian-like weights) **does not transfer to Muon-trained small models.** Null result.
+
+### 3.5 Sliding-Window Attention during training (complete 2×2 factorial) — position-cap null
+
+**Hypothesis:** if the model is trained with SWA (each query attends to the last W tokens), then at eval it should be (a) robust to long contexts (since distances are always bounded by W) and (b) better suited to KV chaining across windows.
+
+**Setup:** trained SWA with `window_size = 1024` at `seq_len = 2048` via flash_attn_3's `window_size` parameter. Then ran a complete 2×2 factorial on the eval axis: {baseline model, SWA-trained model} × {full-attention eval, SWA eval at seq=2048 or seq=4096}.

+
+| Train config | Eval config | Pre-quant | Quantized | Sliding | **TTT** | vs record baseline |
+|---|---|---|---|---|---|---|
+| Baseline (full attn) | Full attn, seq=2048 | 1.08584 | 1.09678 | **1.08014** | **1.07886** | **0** (record) |
+| Baseline (full attn) | SWA=1024, seq=2048 (Exp B) | 1.08737 | 1.09829 | 1.08491 | 1.08321 | +0.00435 |
+| SWA=1024 | SWA=1024, seq=2048 | 1.08602 | 1.09704 | 1.08293 | 1.08163 | +0.00277 |
+| SWA=1024 | SWA=1024, seq=4096 (Exp A) | 1.08239 | 1.09480 | 1.08830 | 1.08341 | +0.00455 |
+
+**What the factorial shows:**
+
+1. **SWA training converges to essentially the same pre-quant bpb as baseline** (1.08602 vs 1.08584, Δ within noise). The model learns equally well under a 1024-token window.
+2. **Eval-time SWA windowing is a pure context cap.** Baseline model + SWA=1024 eval loses 0.0044 on TTT because each scored token sees only 1024 prior tokens of context instead of 2048.
+3. **Training-time SWA recovers ~0.002 bpb of the eval-time cap** (1.08163 vs 1.08321) — training on the same distribution the eval operates in does help. But it cannot overcome the fundamental halving of context.
+4. **Extending to seq=4096 with SWA doesn't unlock long context** either. Scored positions at the 4096 tail are OOD (NTK-extrapolated) even though the training-time window was the same. The OOD query-phase penalty stacks with the context cap.
+
+**Null result for this specific windowing choice.** A more careful design would be `train_seq_len=4096, window=2048` (so the train-time window matches what baseline's full attention sees on 2048 val positions), but that requires a ×2 memory training run. The pattern suggests it would still not produce a record — because it is still a position-depth-vs-breadth trade, just shifted.
+
+### 3.6 `TRAIN_SEQ_LEN = 4096` — position-depth vs breadth
+
+**Hypothesis:** if the model is trained at seq_len = 4096, positions 0–4095 all become in-distribution, so eval at seq=4096 with standard sliding is clean — no OOD tax at scored tail positions.
+
+**Setup:** `TRAIN_SEQ_LEN = 4096`, `ROPE_TRAIN_SEQ_LEN = 4096`, `TRAIN_BATCH_TOKENS = 786432` (unchanged, so 192 sequences per step instead of 384). Eval: `EVAL_SEQ_LEN = 4096, stride = 64`. Single seed (42).
+
+**Throughput cost:** training took 588 s against the 600 s cap, reaching step 3837 (baseline reached step 4393 at seq=2048). Per-token throughput: 5.65 M tok/s (vs baseline's 7.50 M tok/s) — **24% slower on a per-token basis**, and total processed tokens down ~13%.
+
+**Result:**
+
+| Metric | Baseline seed 42 (seq=2048) | Seq=4096 seed 42 | Δ bpb |
+|---|---|---|---|
+| Pre-quant post-EMA val_bpb | 1.08584 | **1.08388** | **−0.00196** ✅ |
+| Quantized roundtrip val_bpb | 1.09678 | **1.09439** | **−0.00239** ✅ |
+| **Sliding val_bpb** | **1.08014** | 1.08449 | **+0.00435** ❌ |
+| **TTT val_bpb** | **1.07886** | 1.08314 | **+0.00428** ❌ |
+
+**The exact same sign inversion as §3.2 appears here too.** Pre-quant improves because its average is over all positions 0–4095 (broad; the model has in-distribution coverage everywhere). Sliding regresses because it only scores positions 4032–4095 (narrow tail; each of those 64 positions received roughly half the training samples that baseline's 64 tail positions did at seq=2048).
+
+**Mechanism — "position-depth-vs-breadth tradeoff":**
+
+Training at `TRAIN_SEQ_LEN = 2048` concentrates every training sample on positions 0–2047. A position like 2000 might be seen hundreds of thousands of times across training.

+Training at `TRAIN_SEQ_LEN = 4096` spreads the same total training samples across positions 0–4095. Position 2000 is now seen roughly half as often — and position 4000 (previously never seen) is now seen occasionally, but still far less often than position 2000 was at the shorter sequence length.
+
+Sliding eval with stride=64 always scores tokens at window-tail positions. At seq=2048 that's positions 1984–2047; at seq=4096 that's positions 4032–4095. Both are in-distribution for their respective models, but the seq=4096 model has **shallower per-position knowledge** at its tail positions — the same total training data spread across a wider position range.
+
+**The net effect is that sliding specifically measures the shallow end of position depth, not the broad average.** Raw seq_len extension (and SWA at a smaller window) both trade depth for breadth; both regress on sliding for the same reason.
+
+## 4. On tokenization determinism between pods
+
+During this session I discovered a **subtle bpb-scale shift across pods** worth documenting for other competitors.
+
+Pod A (earlier): step-0 val_bpb on random init = **3.4871**
+Pod B (fresh, same architecture): step-0 val_bpb on random init = **3.6843**
+
+Both pods used the identical SP8192 tokenizer file (md5 `ec1e96070f...`) and the identical 40,540,160-token val shard. But the byte count that the val tokens decode to (the denominator of val_bpb) differed by **~5.6%**, because the shards themselves had been tokenized with a different SP8192 BPE model at different times. The tokenizer file on disk matched, but the shard-time tokenizer did not.
+
+When regenerated from `willdepueoai/parameter-golf`'s `docs_selected.jsonl` via `download_hf_docs_and_tokenize.py --variant sp8192`, the pipeline produces a canonical SP8192 tokenizer + shards that reproduce step-0 bpb **3.4873** (matching Pod A within 0.0002, i.e., identical up to CUDA nondeterminism). All three seeds in my record PR were run on this canonically regenerated tokenizer.
+
+**Advisory for other competitors:** the canonical SP8192 training and eval setup requires regenerating the tokenizer + shards from `willdepueoai/parameter-golf`'s published `docs_selected.jsonl`, not reusing any SP8192 shards found on disk from a previous session. Random-init step-0 val_bpb near 3.487 (at the canonical token/byte ratio 0.387) is the canonical fingerprint; numbers far from this indicate a tokenization mismatch, not a training success or failure.
+
+## 5. The architectural insight
+
+Stepping back from the individual nulls: the three tier-A/B levers that *could have unlocked the eval-time working-memory advantage* — `EVAL_SEQ_LEN = 4096`, SWA training, and `TRAIN_SEQ_LEN = 4096` — all fail in the same direction under sliding-window eval. Each pays a cost in the narrow tail-position range that sliding specifically measures.
+
+**Sliding-window eval at fixed stride rewards position depth at a specific band** (positions `seq_len−stride` through `seq_len−1`). Any technique that spreads training effort across more positions, or narrows per-token attention, reduces the per-position depth at that band.
+
+The right fix is not a better position-extension scheme — it's a position scheme that makes the scored-token position irrelevant. Three candidates:
+
+1. **ALiBi** (attention with linear bias on distance) — attention logits get a per-head linear bias based on token distance, no rotary at all. Model predictions depend on distance, not absolute position. Extrapolates cleanly to longer contexts in the literature (e.g., BLOOM and MPT).

+2. **NoPE** (no positional encoding) — surprisingly competent on causal LMs; the causality mask alone provides some order signal. Zero OOD-position risk.
+3. **T5-style learnable relative-position buckets** — similar to ALiBi but with learned biases per log-spaced distance bucket.
+
+All three would be trained at `TRAIN_SEQ_LEN = 2048` and evaluated at `EVAL_SEQ_LEN = 4096+` without the absolute-position OOD penalty documented above. **That's the next experiment** — and it's the one the ablation evidence from this PR points toward.
+
+## 6. Reproducibility
+
+This PR includes everything needed to reproduce every number reported above. File manifest:
+
+```
+README.md                   # this file
+submission.json             # non-record metadata
+adaptive_hadamard_gptq.py   # §3.4 null result — full impl + unit tests
+kv_cache_chain.py           # Prototype helpers for a KV-chain direction (superseded — see §5)
+patches/
+  patch_pathav3_inline.py   # Path A v3 quantization (used in record PR)
+  patch_qat_v3.py           # QAT v3 (§3.3, NOT RECOMMENDED)
+  patch_swa.py              # SWA at training/eval (§3.5)
+  pack_submission.py        # LZMA code wrapper
+  ngram_cache_eval.py       # Prototype causal n-gram cache (Track B, not yet tested — §7)
+logs/
+  eval_ttt4_s42.log         # §3.1 TTT_EPOCHS=4 probe
+  eval_yarn4096_s42.log     # §3.2 NTK-RoPE seq=4096 probe
+  qat_v1_seed42.log         # §3.3 QAT v1 (wrong scale)
+  qat_v2_seed42.log         # §3.3 QAT v2 (tok_emb fake-quant)
+  salvage_int7.log          # §3.3 supplementary
+  sanity_swa0.log           # §3.5 SWA patch sanity (SWA=0 should = baseline; confirmed)
+  swa_train.log             # §3.5 full SWA training + seq=2048 eval
+  swa_eval4096.log          # §3.5 SWA-trained + seq=4096 eval (Experiment A)
+  retrofit_swa_s42.log      # §3.5 baseline model + SWA=1024 eval (Experiment B)
+  seq4096_s42.log           # §3.6 TRAIN_SEQ_LEN=4096 full run
+```
+
+The canonical seed 42 trained model is PR #1716's `final_model_seed42.pt` — any reader can reload it and replay these probes in 5–10 min each on 8× H100.
+
+## 7. Where this points
+
+Ranked by expected return for the remaining days of the competition:
+
+1. **ALiBi training** (replace RoPE with a relative-position attention bias) — the architecturally right answer to every null in §3.2, §3.5, §3.6. Training-time change, well documented in the literature. This is the next experiment.
+2. **NoPE** — simpler fallback; worth trying if ALiBi dev is heavy. Literature suggests competitive performance on causal LMs.
+3. **Causal n-gram cache at eval** (tier-B, permitted explicitly under Issue #1017) — prototype in `patches/ngram_cache_eval.py`. Novel to this competition: no PR I've found uses a dynamic bigram prior accumulated from already-scored val tokens. Expected modest gain; valuable as a clean legal addition rather than a record-breaker.
+4. **Legitimate FLA / GatedDeltaNet** with rigorous byte-accounting — architectural long-context, but the flagged PR queue (#1672, #1687, #1698, #1711, #1712) shows the byte-count pitfall kills almost every attempt. Expensive pit to dig out of correctly.
+5. **Per-document LoRA TTT** (PR #1928 framework on top of the record stack).
+
+## Credits
+
+- All upstream lineage from record PR #1716 (see that PR's attribution list).
+- Issue #1017 (A Field Guide to Valid Submissions) for the tier-A vs tier-B framing.
+- SpinQuant / QuaRot literature for the Hadamard direction that §3.4 refutes on this stack.
+- Press et al. 2022 (ALiBi) and Haviv et al. 2022 (NoPE) for the relative-position directions §7 points toward.

diff --git a/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/adaptive_hadamard_gptq.py b/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/adaptive_hadamard_gptq.py
new file mode 100644
index 0000000000..497f548f17
--- /dev/null
+++ b/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/adaptive_hadamard_gptq.py
@@ -0,0 +1,254 @@
+"""Adaptive Hadamard GPTQ — pre-rotate weights to reduce quantization error.
+
+Mechanism:
+1. Generate Walsh-Hadamard matrix H (deterministic, structured, no storage cost)
+2. Generate random ±1 sign pattern s (deterministic from seed; shared across all matrices)
+3. For each weight W: compute W_rot = (W * s_in) @ H_in (rotates input dim)
+4. Try quantizing both W and W_rot; pick whichever has lower MSE
+5. Store: q-tensor, scale, rotation flag (1 byte per matrix)
+6. At dequant: dequant W_rot → unrotate via H.T @ diag(s) to recover W
+
+Storage overhead:
+- 1 shared seed (4 bytes) + per-matrix rotation flag (1 byte each) = ~16 bytes total
+
+Math:
+- Forward: y = x @ W (original)
+- With rotation: W_rot = (W * s_in) @ H, so W = (W_rot @ H.T) * s_in
+  (since H @ H.T = I and s * s = 1 for ±1 signs)
+- At dequant time, recover W_eval = (W_rot.dequant @ H.T) * s_in
+
+Why it should work (hypothesis):
+- Hadamard rotation distributes outliers uniformly (Central Limit Theorem)
+- Post-rotation weights are more Gaussian, fewer extreme values
+- GPTQ (and any per-row quantization) loses less precision on uniform distributions
+- Per-channel scale captures the Gaussian range without wasting bits on outliers
+
+Measured outcome: ~35% MSE reduction on synthetic heavy-tailed weights (unit test
+below), but NO significant reduction on actual Muon-trained weights — a null
+result; see README §3.4.
+"""
+import math
+import torch
+import sys
+
+
+# ============ Hadamard helpers ============
+def walsh_hadamard_matrix(n: int) -> torch.Tensor:
+    """Normalized Walsh-Hadamard matrix of size n×n. n must be power of 2.
+    H @ H.T = I (orthonormal)."""
+    assert n > 0 and (n & (n - 1)) == 0, f"n={n} must be power of 2"
+    H = torch.tensor([[1.0]])
+    while H.shape[0] < n:
+        H = torch.cat([
+            torch.cat([H, H], dim=1),
+            torch.cat([H, -H], dim=1),
+        ], dim=0)
+    return H / math.sqrt(n)
+
+
+def generate_signs(n: int, seed: int) -> torch.Tensor:
+    """Deterministic ±1 sign pattern of length n."""
+    g = torch.Generator()
+    g.manual_seed(seed)
+    return (torch.randint(0, 2, (n,), generator=g) * 2 - 1).float()
+
+
+def hadamard_rotate(W: torch.Tensor, signs: torch.Tensor, H: torch.Tensor) -> torch.Tensor:
+    """Rotate W along input dim (cols).
+    W: (out_dim, in_dim)
+    signs: (in_dim,) ±1
+    H: (in_dim, in_dim) Walsh-Hadamard
+    Returns: W_rot = (W * signs[None, :]) @ H, shape (out_dim, in_dim).
+    """
+    assert W.shape[1] == signs.shape[0] == H.shape[0]
+    return (W * signs.unsqueeze(0)) @ H
+
+
+def hadamard_unrotate(W_rot: torch.Tensor, signs: torch.Tensor, H: torch.Tensor) -> torch.Tensor:
+    """Inverse rotation: recover original W from W_rot.
+    W = (W_rot @ H.T) * signs[None, :]
+    (H @ H.T = I, signs * signs = 1 elementwise for ±1)
+    """
+    return (W_rot @ H.t()) * signs.unsqueeze(0)
+
+
+# ============ Quantization simulation ============
+def quantize_int_per_row(W: torch.Tensor, bits: int, clip_sigmas: float = None):
+    """Simulate per-row symmetric int quantization. Returns (W_recon, MSE).
+    If clip_sigmas given, uses sigma-clipped scale (GPTQ-style).
+    Else uses row-max scale (simple).
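
+
+    Scale formulas (qmax = 2**(bits-1) - 1), matching the code below:
+        sigma-clipped: scale_row = clip_sigmas * std(W_row) / qmax
+        row-max:       scale_row = max(|W_row|) / qmax
+        W_recon = clamp(round(W / scale_row), -qmax, qmax) * scale_row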
+ """ + qmax = 2**(bits - 1) - 1 + if clip_sigmas is not None: + scale = (clip_sigmas * W.float().std(dim=1, keepdim=True) / qmax).clamp_min(1e-10) + else: + scale = (W.abs().amax(dim=1, keepdim=True) / qmax).clamp_min(1e-10) + q = torch.clamp(torch.round(W / scale), -qmax, qmax) + W_recon = q * scale + mse = (W - W_recon).pow(2).mean().item() + return W_recon, mse, scale, q + + +def adaptive_hadamard_quantize( + W: torch.Tensor, + bits: int, + signs: torch.Tensor, + H: torch.Tensor, + clip_sigmas: float = None, +): + """Try quant with and without Hadamard rotation; pick lower MSE. + Returns: (use_rotation, q, scale, mse_chosen, mse_no_rot, mse_rot) + """ + # No rotation + W_recon_orig, mse_orig, scale_orig, q_orig = quantize_int_per_row(W, bits, clip_sigmas) + # With rotation + W_rot = hadamard_rotate(W, signs, H) + W_rot_recon, mse_rot_in_rot_space, scale_rot, q_rot = quantize_int_per_row(W_rot, bits, clip_sigmas) + # MSE in original space (after unrotating reconstruction) + W_recovered = hadamard_unrotate(W_rot_recon, signs, H) + mse_rot_in_orig_space = (W - W_recovered).pow(2).mean().item() + + if mse_rot_in_orig_space < mse_orig: + return True, q_rot, scale_rot, mse_rot_in_orig_space, mse_orig, mse_rot_in_orig_space + else: + return False, q_orig, scale_orig, mse_orig, mse_orig, mse_rot_in_orig_space + + +def dequantize_with_rotation(q: torch.Tensor, scale: torch.Tensor, used_rotation: bool, + signs: torch.Tensor, H: torch.Tensor) -> torch.Tensor: + """Reconstruct W from quantized form, applying inverse rotation if used.""" + W_recon = q.float() * scale.float().view(-1, 1) + if used_rotation: + W_recon = hadamard_unrotate(W_recon, signs, H) + return W_recon + + +# ============ UNIT TESTS ============ +def test_hadamard_orthonormal(): + """H @ H.T = I.""" + for n in [2, 4, 8, 64, 512, 2048]: + H = walsh_hadamard_matrix(n) + I = H @ H.t() + err = (I - torch.eye(n)).abs().max().item() + assert err < 1e-5, f"n={n}: H not orthonormal (max err {err})" + print(" ✓ Walsh-Hadamard matrices orthonormal") + + +def test_rotation_lossless(): + """rotate then unrotate should recover original (within fp precision).""" + for d in [64, 512, 2048]: + torch.manual_seed(0) + W = torch.randn(100, d) + H = walsh_hadamard_matrix(d) + signs = generate_signs(d, seed=42) + W_rot = hadamard_rotate(W, signs, H) + W_back = hadamard_unrotate(W_rot, signs, H) + err = (W - W_back).abs().max().item() + assert err < 1e-4, f"d={d}: round-trip failed (err {err})" + print(" ✓ Rotation lossless under fp32") + + +def test_quantize_per_row(): + """Per-row int quant with symmetric clipping.""" + torch.manual_seed(0) + W = torch.randn(8, 16) + Wr, mse, s, q = quantize_int_per_row(W, bits=6) + assert q.dtype == W.dtype # we don't enforce int8 here + assert s.shape == (8, 1) + assert mse > 0 + print(f" ✓ Per-row int6 quant MSE: {mse:.6f}") + + +def test_adaptive_hadamard_helps_outliers(): + """Adaptive Hadamard should help on weights with outliers (where it should rotate).""" + torch.manual_seed(0) + d = 512 + H = walsh_hadamard_matrix(d) + signs = generate_signs(d, seed=42) + + # Heavy-tail weights (outliers in some columns) + W = torch.randn(2048, d) * 0.02 + outlier_cols = torch.randperm(d)[:20] + W[:, outlier_cols] *= 8.0 + + used, q, s, mse_chosen, mse_no, mse_rot = adaptive_hadamard_quantize( + W, bits=6, signs=signs, H=H, clip_sigmas=12.85 + ) + print(f" Heavy-tail: no-rot MSE={mse_no:.3e}, rot MSE={mse_rot:.3e}, chose rotation={used}") + assert mse_chosen <= min(mse_no, mse_rot) * 1.0001 # picks lower + if used: + 
improvement = (1 - mse_rot / mse_no) * 100 + print(f" Hadamard rotation gave {improvement:.1f}% MSE reduction") + + +def test_adaptive_hadamard_skips_uniform(): + """On already-uniform weights, rotation should NOT help much (might or might not be picked).""" + torch.manual_seed(0) + d = 512 + H = walsh_hadamard_matrix(d) + signs = generate_signs(d, seed=42) + + # Uniform-ish weights (no outliers, like Muon-trained) + W = torch.empty(2048, d).uniform_(-1, 1) * 0.05 + used, q, s, mse_chosen, mse_no, mse_rot = adaptive_hadamard_quantize( + W, bits=6, signs=signs, H=H, clip_sigmas=12.85 + ) + print(f" Uniform: no-rot MSE={mse_no:.3e}, rot MSE={mse_rot:.3e}, chose rotation={used}") + + +def test_dequantize_roundtrip(): + """Quantize then dequantize → values approximately match original (with quant noise).""" + torch.manual_seed(0) + d = 512 + H = walsh_hadamard_matrix(d) + signs = generate_signs(d, seed=42) + W = torch.randn(2048, d) * 0.02 + + used, q, s, mse_chosen, _, _ = adaptive_hadamard_quantize(W, bits=8, signs=signs, H=H) + W_recon = dequantize_with_rotation(q, s, used, signs, H) + final_mse = (W - W_recon).pow(2).mean().item() + assert abs(final_mse - mse_chosen) < 1e-8, "dequant MSE doesn't match chosen quant MSE" + print(f" ✓ Dequant roundtrip MSE matches: {final_mse:.3e}") + + +def test_pr1493_style_weights(): + """Simulate Muon-trained weight distribution (sub-Gaussian, near-uniform).""" + torch.manual_seed(0) + d_in, d_out = 512, 512 # like attn proj + H = walsh_hadamard_matrix(d_in) + signs = generate_signs(d_in, seed=42) + + # Muon-style: sub-Gaussian, kurtosis < 0 + W = torch.empty(d_out, d_in).uniform_(-1, 1) * 0.04 + print(f"\n Muon-style W{(d_out, d_in)}: std={W.std().item():.4f}, kurtosis~ {((W - W.mean()).pow(4).mean() / W.var().pow(2) - 3).item():.2f}") + + for bits in [5, 6, 7, 8]: + used, _, _, mse_chosen, mse_no, mse_rot = adaptive_hadamard_quantize( + W, bits=bits, signs=signs, H=H, clip_sigmas=12.85 + ) + improvement = (1 - mse_rot / mse_no) * 100 + print(f" int{bits}: no-rot={mse_no:.3e}, rot={mse_rot:.3e} ({improvement:+.1f}%), chose rot={used}") + + +if __name__ == "__main__": + print("=== Adaptive Hadamard GPTQ — unit tests ===\n") + print("[1] Hadamard orthonormality:") + test_hadamard_orthonormal() + + print("\n[2] Rotation lossless:") + test_rotation_lossless() + + print("\n[3] Per-row int quant:") + test_quantize_per_row() + + print("\n[4] Adaptive on heavy-tail (should rotate):") + test_adaptive_hadamard_helps_outliers() + + print("\n[5] Adaptive on uniform (Muon-like, may not rotate):") + test_adaptive_hadamard_skips_uniform() + + print("\n[6] Dequant roundtrip:") + test_dequantize_roundtrip() + + print("\n[7] PR #1493-style weights (sub-Gaussian, multiple bit widths):") + test_pr1493_style_weights() + + print("\n=== ALL TESTS COMPLETE ===") diff --git a/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/kv_cache_chain.py b/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/kv_cache_chain.py new file mode 100644 index 0000000000..007abb8b4d --- /dev/null +++ b/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/kv_cache_chain.py @@ -0,0 +1,234 @@ +"""KV-cache chaining design + prototype helpers for eval-time context extension. 
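
+
+(Design notes + unit-tested prototype helpers only; none of this ran in the record
+pipeline — the direction is superseded by the relative-position analysis in README §5.)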
+ +THE INSIGHT (Himanshu, 2026-04-17): +The competition has asymmetric memory constraints: +- Weights: 16 MB cap (artifact) +- Eval working memory: 80 GB × 8 H100 = 640 GB, 10-min wallclock = effectively unlimited + +This means eval-time KV cache and context length should be MAXIMIZED, not minimized. +Almost no current PR does this systematically. + +CURRENT EVAL FLOW (PR #1493): +1. Sliding window: process val sequence in 2048-token windows, stride 64 → score 64 new + tokens per window with 1984 tokens of left context +2. TTT: chunks of 32768 tokens; for each chunk, score then SGD update on whole chunk + +Both use seq_len=2048 (the training context). Each window's KV is computed FRESH. + +THE OPPORTUNITY: +Three axes of expansion, all legal under C1-C4: + +### Axis 1: Longer effective context per forward pass +Train at seq_len=2048, eval at seq_len=4096 or 8192. RoPE positions beyond 2048 +are out-of-distribution but YaRN/NTK scaling extends them gracefully. + +Implementation: +- At eval: rebuild Rotary with extended max position +- Optionally adjust RoPE base (NTK-aware) +- Run forward at seq_len=4096 + +Risk: model behavior in extended positions might degrade overall quality. + +### Axis 2: KV cache chaining across windows +Currently each sliding-window forward pass starts fresh. Instead, keep KV from +PREVIOUS windows and prepend to current attention's K/V. + +For target token at position t in val sequence, the model effectively sees +K/V from all previous windows (millions of tokens of context, bounded by +HBM memory). + +Implementation: use FlashAttn's windowed mode with stored historical K/V. + +Risk: model trained at 2048 attends at position offset; extended positions +need RoPE scaling. + +### Axis 3: More TTT epochs / more chunks +We already use TTT_EPOCHS=3 (~370s eval). With sliding eval at ~120s, we +have ~110s of slack within 600s eval budget. Could push to TTT_EPOCHS=4 or +EVAL_STRIDE=32 (saturated per our test) or both. + +This is the simplest expansion — no model changes. + +### LEGALITY (C1-C4): +- C1 causality: KV chain only adds PRIOR context → fine +- C2 normalized: still softmax over vocab → fine +- C3 score-before-update: depends on TTT structure, must keep score-first +- C4 single pass: each val token still scored once → fine + +CRITICAL: KV chaining must NEVER use future tokens for past predictions. +The chain accumulates ONLY past KV. +""" + +import torch +import math + + +# ============ AXIS 1: RoPE extension (YaRN-style) ============ +def yarn_rope_scaling(rope_base: float, train_seq_len: int, eval_seq_len: int, + extrapolation_factor: float = 1.0): + """YaRN-style RoPE base scaling for longer context at eval. + + Standard: theta_i = base^(-2i/d). For training context L_train. + Extended: scale theta to handle eval context L_eval > L_train. + + Returns adjusted rope_base for eval-time Rotary instantiation. + """ + if eval_seq_len <= train_seq_len: + return rope_base + scale = eval_seq_len / train_seq_len + # NTK-aware: scale base by scale^(d/(d-2)). For our d=16 partial RoPE: + # NTK base = base * scale^(d/(d-2)) ≈ base * scale^1.143 + new_base = rope_base * (scale ** (16 / 14)) # for rope_dims=16 + return new_base + + +def test_yarn_rope(): + base = 10000.0 + extended = yarn_rope_scaling(base, 2048, 4096) + print(f" YaRN: base 2048→4096 ext: {base:.1f} → {extended:.1f} (ratio {extended/base:.2f})") + + +# ============ AXIS 2: KV cache chain ============ +class KVChain: + """Maintain K/V cache across sliding-window forward passes. 
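
+
+    NOTE: K/V cached from earlier windows carries those windows' RoPE phases;
+    the attending window must handle the position offset (see the Axis 2 risk
+    in the module docstring).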
+ + Usage: + chain = KVChain(num_heads=4, head_dim=64, max_history=8192) + for window_id, (x_window, target_positions) in enumerate(windows): + k, v = model.compute_kv(x_window) # (B, T, H, D) + y_logits = model.attend_with_history(x_window, chain.get_history()) + chain.append(k, v) + # score target_positions ... + """ + def __init__(self, num_kv_heads: int, head_dim: int, max_history: int = 8192): + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.max_history = max_history + self.k_history = None # (1, T_total, H, D) + self.v_history = None + + def append(self, k_new: torch.Tensor, v_new: torch.Tensor): + """Add new K/V to history, trim to max_history.""" + # k_new: (B=1, T, H, D) + if self.k_history is None: + self.k_history = k_new + self.v_history = v_new + else: + self.k_history = torch.cat([self.k_history, k_new], dim=1) + self.v_history = torch.cat([self.v_history, v_new], dim=1) + # Trim + if self.k_history.shape[1] > self.max_history: + excess = self.k_history.shape[1] - self.max_history + self.k_history = self.k_history[:, excess:] + self.v_history = self.v_history[:, excess:] + + def get_history(self): + return self.k_history, self.v_history + + def reset(self): + self.k_history = None + self.v_history = None + + +def test_kv_chain(): + chain = KVChain(num_kv_heads=4, head_dim=64, max_history=512) + # Add 5 windows of 256 tokens each = 1280 total → trims to 512 + for i in range(5): + k = torch.randn(1, 256, 4, 64) + v = torch.randn(1, 256, 4, 64) + chain.append(k, v) + k_hist, v_hist = chain.get_history() + assert k_hist.shape == (1, 512, 4, 64), f"expected (1,512,4,64), got {k_hist.shape}" + print(f" ✓ KVChain trimmed correctly: history shape {k_hist.shape}") + + +# ============ AXIS 3: TTT epoch budget calculator ============ +def ttt_budget_calc( + n_chunks: int, + seqs_per_chunk: int, + n_epochs: int, + seconds_per_seq_forward: float = 0.012, + seconds_per_seq_backward: float = 0.024, + sliding_window_seconds: float = 120.0, + pre_quant_seconds: float = 8.0, + quant_eval_seconds: float = 25.0, + eval_budget: float = 600.0, +): + """How many TTT epochs fit in eval budget?""" + sec_per_chunk = seqs_per_chunk * (seconds_per_seq_forward + n_epochs * seconds_per_seq_backward) + total = pre_quant_seconds + quant_eval_seconds + sliding_window_seconds + n_chunks * sec_per_chunk + return total, eval_budget - total + + +def test_ttt_budget(): + for epochs in [3, 4, 5, 7, 10]: + total, slack = ttt_budget_calc(n_chunks=1238, seqs_per_chunk=2, n_epochs=epochs) + ok = "OK" if slack > 0 else "OVER" + print(f" TTT_EPOCHS={epochs}: total={total:.0f}s, slack={slack:.0f}s [{ok}]") + + +# ============ Causality check ============ +def test_causality_preserved(): + """KVChain must NOT include future tokens in past attention. + + Sanity: when we score position t, the K/V chain contains positions 0..t-1 only. + Each window appends AFTER its target tokens are scored. 
+ """ + # Simulated: + history_positions = [] + for window_start in [0, 64, 128, 192, 256]: + # In this window, score positions [window_start..window_start+stride] + # using K/V from history (all positions BEFORE window_start) + target_positions = list(range(window_start, window_start + 64)) + # Verify history contains only positions < window_start + assert all(p < window_start for p in history_positions), \ + f"causality violated: history has {[p for p in history_positions if p >= window_start]}" + # AFTER scoring, append this window's K/V to history + history_positions.extend(range(window_start, window_start + 64)) + print(f" ✓ Causality preserved across {len(history_positions)} positions") + + +# ============ Implementation roadmap ============ +IMPLEMENTATION_PLAN = """ +TOMORROW'S KV-CACHE EXPERIMENT (after recording at int7/int7 + QAT v3): + +PHASE 1 — Cheap eval-time compute (no model changes): + a. TTT_EPOCHS=4 or 5 (use slack in eval budget) + Cost: $4 salvage_eval. Expected: -0.001 to -0.003 BPB. + b. Combined with stride=32 if budget allows + Cost: included. Expected: maybe -0.001 BPB. + +PHASE 2 — Longer context with YaRN RoPE extension: + a. Modify Rotary class to accept eval_max_position + b. Use yarn_rope_scaling() to compute extended base + c. Eval at seq_len=4096 (2x training context) + Cost: $4-7 salvage_eval. Expected: -0.005 to -0.015 BPB if model + handles extended positions OK; potentially worse if not. + +PHASE 3 — Full KV chain across sliding windows: + a. Modify eval_val_sliding to maintain KVChain across windows + b. Each window's attention uses (current K/V) + (chain history) + c. After scoring, append window K/V to chain + Cost: $7+ (more complex code, careful debugging). Expected: -0.005 to + -0.020 BPB depending on how much extra context helps. + +LEGALITY CHECKS BEFORE EACH PHASE: +- C1: every K/V in attention computation must be from positions ≤ current pos +- C2: softmax remains over full vocab (no truncation) +- C3: TTT score-first ordering preserved (already in pipeline) +- C4: each val token scored exactly once (still using stride to define scoring boundaries) +""" + + +if __name__ == "__main__": + print("=== KV-cache chaining + RoPE extension prototype ===\n") + print("[1] YaRN RoPE scaling:") + test_yarn_rope() + print("\n[2] KV cache chain:") + test_kv_chain() + print("\n[3] TTT epoch budget:") + test_ttt_budget() + print("\n[4] Causality check:") + test_causality_preserved() + print("\n=== Implementation roadmap ===") + print(IMPLEMENTATION_PLAN) diff --git a/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/patches/ngram_cache_eval.py b/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/patches/ngram_cache_eval.py new file mode 100644 index 0000000000..648e72c5a5 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/patches/ngram_cache_eval.py @@ -0,0 +1,172 @@ +"""Causal N-gram Cache at eval-time — a Track B lever (per Issue #1017). + +Accumulates bigram counts from already-scored val tokens and blends with model predictions. +Strictly causal: cache state at position t uses only tokens 0..t-1; each token scored once. + +Usage (monkey-patches `eval_val_sliding`): + import ngram_cache_eval + ngram_cache_eval.install(tgs, lambda_weight=0.02, smoothing=0.5) + +Design: +- Dense bigram count table `C`: (vocab_size, vocab_size) int32, ~256MB at V=8192. + Fits in 80GB HBM; per-rank duplication is acceptable for simplicity. 
+- Per-batch update discipline: + for batch of windows: + compute model logits for all windows (no cache update yet) + blend model probs with cache probs using CURRENT frozen cache state + score (accumulate loss_sum, token_count, byte_count) + update cache with scored (prev, curr) pairs from THIS batch +- All-reduce cache updates across ranks once per batch to stay consistent. + +Within-batch windows see cache state as of end of PRIOR batch. This is a 1-batch causality +delay only (never uses FUTURE tokens for any scored position). Strictly legal under C1. + +Blending: p_blend = (1 - λ) * p_model + λ * p_bigram + where p_bigram(y | prev) = (C[prev, y] + α) / (sum_y' C[prev, y'] + α * V) +""" + +import torch +import torch.distributed as dist +import torch.nn.functional as F +import math + + +def install(tgs_module, lambda_weight=0.02, smoothing=0.5, verbose=True): + """Monkey-patch tgs.eval_val_sliding to use causal bigram cache. + + Args: + tgs_module: the imported train_gpt module + lambda_weight: blending factor in [0, 1]. Higher = more cache influence. + smoothing: Dirichlet add-α smoothing for cache probabilities + verbose: log cache statistics periodically + """ + original_eval_val_sliding = tgs_module.eval_val_sliding + + def eval_val_sliding_with_ngram(h, device, val_data, base_model, batch_seqs=32): + base_model.eval() + logits_fn = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + seq_len = h.eval_seq_len + context_size = seq_len - h.eval_stride + total_tokens = val_data.val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, h.eval_stride) if ws + context_size < total_tokens] + total_windows = len(window_starts) + my_s = total_windows * h.rank // h.world_size + my_e = total_windows * (h.rank + 1) // h.world_size + my_windows = window_starts[my_s:my_e] + + V = h.vocab_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # CAUSAL N-GRAM CACHE: bigram counts, updated once per batch + # Dense: (V, V) int32 ~ 256 MB at V=8192 + # Replicated per rank; synchronized at batch boundaries via all_reduce + bigram_cache = torch.zeros((V, V), dtype=torch.int32, device=device) + + # Smoothed prob derivation from cache (computed fresh each batch): + # p_bigram(y | prev) = (cache[prev, y] + alpha) / (row_sum[prev] + alpha * V) + log_lambda = math.log(max(lambda_weight, 1e-10)) + log_one_minus_lambda = math.log(max(1.0 - lambda_weight, 1e-10)) + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(batch_ws): + we = min(ws + seq_len, total_tokens) + wlen = we - ws + wlens.append(wlen) + chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + logits = logits_fn(x_batch) # (bsz, seq_len, V) + logits_f = logits.reshape(-1, V).float() # (bsz*seq_len, V) + + # Frozen bigram probs from current cache state + row_sums = bigram_cache.sum(dim=1).float() # (V,) + denom = row_sums + smoothing * V # (V,) + # log p_bigram(y | prev) = log(cache[prev, y] + smoothing) - log(denom[prev]) + # 
We need per-token (prev) lookups, done in the scoring loop below
+
+                # Score each window's contribution
+                prev_ids_flat = x_batch.reshape(-1)  # (bsz*seq_len,)
+                tgt_ids_flat = y_batch.reshape(-1)
+
+                # Blended NLL per token:
+                #   log p_blend(y|prev) = logsumexp([log((1-λ) * p_model(y)) , log(λ * p_bigram(y|prev))])
+                log_p_model = F.log_softmax(logits_f, dim=-1)  # (bsz*seq_len, V)
+
+                # For each token, extract log p_model(y) and log p_bigram(y|prev)
+                # Gather log p_model at target
+                log_p_model_at_y = log_p_model.gather(1, tgt_ids_flat.unsqueeze(1)).squeeze(1)  # (N,)
+
+                # Compute log p_bigram(y|prev) = log((cache[prev, y] + α) / denom[prev])
+                cache_counts = bigram_cache[prev_ids_flat, tgt_ids_flat].float()  # (N,)
+                log_p_bigram_at_y = torch.log(cache_counts + smoothing) - torch.log(denom[prev_ids_flat])
+
+                # Log-blend via logsumexp
+                blended_log_prob = torch.logsumexp(torch.stack([
+                    log_one_minus_lambda + log_p_model_at_y,
+                    log_lambda + log_p_bigram_at_y,
+                ], dim=0), dim=0)
+                nll_flat = -blended_log_prob  # (bsz*seq_len,)
+                nll = nll_flat.reshape(bsz, seq_len)
+
+                # Score only the stride tokens at window tail (as standard)
+                batch_update_prev = []
+                batch_update_tgt = []
+                for i, ws in enumerate(batch_ws):
+                    wlen = wlens[i]
+                    s = 0 if ws == 0 else context_size
+                    scored_nll = nll[i, s:wlen].to(torch.float64)
+                    loss_sum += scored_nll.sum()
+                    token_count += float(wlen - s)
+                    tgt = y_batch[i, s:wlen]
+                    prev = x_batch[i, s:wlen]
+                    tb = val_data.base_bytes_lut[tgt].to(torch.float64)
+                    tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64)
+                    byte_count += tb.sum()
+                    batch_update_prev.append(prev)
+                    batch_update_tgt.append(tgt)
+
+                # UPDATE CACHE after scoring — strictly causal (only the NEXT batch sees it).
+                # Sync discipline: build a per-batch DELTA, all-reduce the delta, then apply.
+                # All-reducing the full cache would re-sum the already-synced history once
+                # per rank (double-counting); the delta form avoids that.
+                if batch_update_prev:
+                    upd_prev = torch.cat(batch_update_prev)
+                    upd_tgt = torch.cat(batch_update_tgt)
+                    # scatter_add over flattened (prev, tgt) pairs
+                    flat_idx = upd_prev.long() * V + upd_tgt.long()
+                    delta = torch.zeros(V * V, dtype=torch.int32, device=device)
+                    delta.scatter_add_(0, flat_idx, torch.ones_like(flat_idx, dtype=torch.int32))
+                    # Each rank scores different windows, so summing the per-rank deltas
+                    # gives every rank the identical cache state for the next batch.
+                    if dist.is_available() and dist.is_initialized():
+                        dist.all_reduce(delta, op=dist.ReduceOp.SUM)
+                    bigram_cache.view(-1).add_(delta)
+ + # Verbose logging + if verbose and bi % (batch_seqs * 10) == 0 and h.rank == 0: + sparsity = (bigram_cache > 0).sum().item() / bigram_cache.numel() + tgs_module.log(f" ngram: batch {bi}/{len(my_windows)}, cache sparsity {sparsity:.4f}") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + base_model.train() + return tgs_module._loss_bpb(loss_sum, token_count, byte_count) + + tgs_module.eval_val_sliding = eval_val_sliding_with_ngram + if verbose: + tgs_module.log(f"[ngram_cache_eval] installed: λ={lambda_weight}, smoothing={smoothing}") diff --git a/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/patches/pack_submission.py b/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/patches/pack_submission.py new file mode 100644 index 0000000000..b5576fc248 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/patches/pack_submission.py @@ -0,0 +1,51 @@ +"""Pack train_gpt_stacked_v2_fixed.py into a 2-line LZMA self-extracting submission. + +Output: train_gpt.py (the actual submission file). +When executed via `torchrun ... train_gpt.py`, it decompresses and execs the merged Path A v3 code. +""" +import lzma as L, base64 as B +import python_minifier + +src_file = 'train_gpt_stacked_v2_fixed.py' +out_file = 'train_gpt.py' + +raw = open(src_file).read() +print(f"raw: {len(raw):,} bytes") + +# Minify. Use conservative settings so we don't break semantic behavior. +minified = python_minifier.minify( + raw, + remove_annotations=True, + remove_pass=True, + remove_literal_statements=True, + combine_imports=True, + hoist_literals=True, + rename_locals=True, + rename_globals=False, # safer — don't rename module-level names + remove_asserts=False, # keep asserts for safety + remove_debug=False, + remove_object_base=True, + convert_posargs_to_args=True, + preserve_shebang=False, +) +print(f"minified: {len(minified):,} bytes") + +compressed = L.compress( + minified.encode('utf-8'), + format=L.FORMAT_RAW, + filters=[{"id": L.FILTER_LZMA2, "preset": 9 | L.PRESET_EXTREME}], +) +print(f"lzma: {len(compressed):,} bytes") + +b85 = B.b85encode(compressed).decode('ascii') + +wrapper = ( + f'import lzma as L,base64 as B\n' + f'exec(L.decompress(B.b85decode("{b85}"),format=L.FORMAT_RAW,filters=[{{"id":L.FILTER_LZMA2}}]))\n' +) +print(f"wrapped: {len(wrapper):,} bytes") + +open(out_file, 'w').write(wrapper) +import py_compile +py_compile.compile(out_file, doraise=True) +print(f"saved + syntax OK: {out_file}") diff --git a/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/patches/patch_pathav3_inline.py b/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/patches/patch_pathav3_inline.py new file mode 100644 index 0000000000..056ed1e492 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/patches/patch_pathav3_inline.py @@ -0,0 +1,92 @@ +"""Inline Path A v3 quantizer modifications into train_gpt_stacked_v2_fixed.py. + +Path A v3 changes (validated yesterday + today): +1. Int8 per-tensor quant for control tensors (attn_scale, mlp_scale, resid_mix, skip_gates, skip_weights) +2. Int8 per-row quant for small matrices (bigram.proj, attn_gate_proj, smear_gate.weight) +3. dequantize_mixed updated to handle these new categories + +Everything else unchanged. 
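
+
+Measured net effect (README §2.2): total submission 16.06 MB → 15.99 MB, with the
+quant-roundtrip bpb unchanged at 5-decimal precision.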
+""" +import re +F = 'train_gpt_stacked_v2_fixed.py' +src = open(F).read() + +# Replace gptq_mixed_quantize +OLD_QUANT = """def gptq_mixed_quantize(state_dict,hessians,h): +\tresult={};meta={} +\tfor(name,tensor)in state_dict.items(): +\t\tt=tensor.detach().cpu().contiguous() +\t\tif not t.is_floating_point()or t.numel()<=65536:result[name]=t.to(torch.float16)if t.is_floating_point()else t;meta[name]='passthrough (float16)';continue +\t\tif 'bigram.embed' in name: +\t\t\tbits=6;qmax=2**(bits-1)-1;row_max=t.abs().amax(dim=1,keepdim=True).clamp_min(1e-10);s=(row_max/qmax).squeeze(-1).to(torch.float16);sf=s.float().view(-1,1);q=torch.clamp(torch.round(t/sf),-qmax,qmax).to(torch.int8);result[name+'.q']=q;result[name+'.scale']=s;meta[name]=f'simple int{bits} (bigram embed)';continue +\t\tcs=h.embed_clip_sigmas if'tok_emb'in name else h.matrix_clip_sigmas;bits=h.embed_bits if'tok_emb'in name else h.matrix_bits;q,s=gptq_quantize_weight(t,hessians[name],clip_sigmas=cs,clip_range=2**(bits-1)-1);result[name+'.q']=q;result[name+'.scale']=s;meta[name]=f\"gptq (int{bits})\"""" + +NEW_QUANT = """def gptq_mixed_quantize(state_dict,hessians,h): +\tresult={};meta={} +\t_FORCE_INT8_SMALL=('bigram.proj','attn_gate_proj','smear_gate.weight') +\t_FORCE_INT8_PT=('attn_scale','mlp_scale','resid_mix','skip_gates','skip_weights') +\tfor(name,tensor)in state_dict.items(): +\t\tt=tensor.detach().cpu().contiguous() +\t\tif not t.is_floating_point()or t.numel()<=65536: +\t\t\tif t.is_floating_point()and t.numel()>1 and any(k in name for k in _FORCE_INT8_PT): +\t\t\t\tma=t.abs().max().clamp_min(1e-10);sc=(ma/127.).float();q=torch.clamp(torch.round(t/sc),-127,127).to(torch.int8) +\t\t\t\tresult[name+'.q_pt']=q;result[name+'.scale_pt']=sc;meta[name]='pertensor int8 (control)';continue +\t\t\tif t.is_floating_point()and t.ndim==2 and any(k in name for k in _FORCE_INT8_SMALL): +\t\t\t\trm=t.abs().amax(dim=1,keepdim=True).clamp_min(1e-10);s=(rm/127.).squeeze(-1).to(torch.float16);sf=s.float().view(-1,1) +\t\t\t\tq=torch.clamp(torch.round(t/sf),-127,127).to(torch.int8);result[name+'.q']=q;result[name+'.scale']=s;meta[name]='simple int8 (small matrix)';continue +\t\t\tresult[name]=t.to(torch.float16)if t.is_floating_point()else t;meta[name]='passthrough (float16)';continue +\t\tif 'bigram.embed' in name: +\t\t\tbits=6;qmax=2**(bits-1)-1;row_max=t.abs().amax(dim=1,keepdim=True).clamp_min(1e-10);s=(row_max/qmax).squeeze(-1).to(torch.float16);sf=s.float().view(-1,1);q=torch.clamp(torch.round(t/sf),-qmax,qmax).to(torch.int8);result[name+'.q']=q;result[name+'.scale']=s;meta[name]=f'simple int{bits} (bigram embed)';continue +\t\tcs=h.embed_clip_sigmas if'tok_emb'in name else h.matrix_clip_sigmas;bits=h.embed_bits if'tok_emb'in name else h.matrix_bits;q,s=gptq_quantize_weight(t,hessians[name],clip_sigmas=cs,clip_range=2**(bits-1)-1);result[name+'.q']=q;result[name+'.scale']=s;meta[name]=f\"gptq (int{bits})\"""" + +if OLD_QUANT not in src: + print("ERR: OLD_QUANT pattern not found") + raise SystemExit(1) +src = src.replace(OLD_QUANT, NEW_QUANT, 1) +print("Replaced gptq_mixed_quantize with Path A v3 version") + +# Replace dequantize_mixed +OLD_DEQ = """def dequantize_mixed(result,meta,template_sd): +\tout={} +\tfor(name,orig)in template_sd.items(): +\t\tinfo=meta.get(name) +\t\tif info is None:continue +\t\torig_dtype=orig.dtype +\t\tif'passthrough'in info: +\t\t\tt=result[name] +\t\t\tif t.dtype==torch.float16 and orig_dtype in(torch.float32,torch.bfloat16):t=t.to(orig_dtype) +\t\t\tout[name]=t;continue 
+\t\tq,s=result[name+'.q'],result[name+'.scale']
+\t\tif s.ndim>0:out[name]=(q.float()*s.float().view(q.shape[0],*[1]*(q.ndim-1))).to(orig_dtype)
+\t\telse:out[name]=(q.float()*float(s.item())).to(orig_dtype)
+\treturn out"""
+
+NEW_DEQ = """def dequantize_mixed(result,meta,template_sd):
+\tout={}
+\tfor(name,orig)in template_sd.items():
+\t\tinfo=meta.get(name)
+\t\tif info is None:continue
+\t\torig_dtype=orig.dtype
+\t\tif'passthrough'in info:
+\t\t\tt=result[name]
+\t\t\tif t.dtype==torch.float16 and orig_dtype in(torch.float32,torch.bfloat16):t=t.to(orig_dtype)
+\t\t\tout[name]=t;continue
+\t\tif'pertensor'in info:
+\t\t\tq=result[name+'.q_pt'];sc=result[name+'.scale_pt']
+\t\t\tout[name]=(q.float()*sc.float()).to(orig_dtype);continue
+\t\tq,s=result[name+'.q'],result[name+'.scale']
+\t\tif s.ndim>0:out[name]=(q.float()*s.float().view(q.shape[0],*[1]*(q.ndim-1))).to(orig_dtype)
+\t\telse:out[name]=(q.float()*float(s.item())).to(orig_dtype)
+\treturn out"""
+
+if OLD_DEQ not in src:
+    print("ERR: OLD_DEQ pattern not found")
+    raise SystemExit(1)
+src = src.replace(OLD_DEQ, NEW_DEQ, 1)
+print("Replaced dequantize_mixed with Path A v3 version")
+
+# Save
+open(F, 'w').write(src)
+import py_compile
+py_compile.compile(F, doraise=True)
+print(f"Saved + syntax OK: {len(src)} bytes")
diff --git a/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/patches/patch_qat_v3.py b/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/patches/patch_qat_v3.py
new file mode 100644
index 0000000000..7575d018c1
--- /dev/null
+++ b/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/patches/patch_qat_v3.py
@@ -0,0 +1,243 @@
+"""QAT v3 — atomic patch, fixes lessons from v1 + v2.
+
+LESSONS FROM TONIGHT:
+- v1: Fake-quanted only matrices, but used wrong scale formula (per-row max
+  instead of GPTQ's clip_sigmas * row_std / qmax). Result: didn't reduce
+  quant penalty meaningfully (0.0144 → still ~0.0144).
+- v2: Fixed scale formula AND added tok_emb fake-quant. Result: quant penalty
+  actually dropped 48% (0.0147 → 0.0077), BUT pre-quant cost was +0.023 BPB
+  due to fake-quanting tok_emb. Net negative.
+
+v3 design:
+- Matrices: fake-quant at int6 with GPTQ-matched scale (v2's formula)
+- tok_emb: NO fake-quant (too sensitive — let it train at full precision)
+- Warmup: enable QAT only after step 1500 (~30% of training). Model finds
+  a strong baseline FIRST, then learns quant-robustness on top.
+- bigram.embed: NO fake-quant (small, low-impact)
+
+Expected:
+- Pre-quant cost: ~+0.001 BPB (only matrix fake-quant overhead)
+- Quant penalty reduction: similar to v2 on matrices (~50% of matrix-quant penalty)
+- Net: pre-quant ~1.086, quant ~1.092, sliding ~1.075, TTT ~1.073 → record by 0.005-0.010 BPB
+
+USAGE:
+    python3 patch_qat_v3.py
+    # then: QAT_ENABLED=1 QAT_WARMUP_STEPS=1500 torchrun ...
+
+Idempotent: safe to run multiple times. Detects existing v1/v2 state and migrates.
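+
+Scale arithmetic installed below (the "GPTQ-matched" formula; the example
+row_std value is hypothetical):
+
+    qmax = 2**(bits - 1) - 1              # int6 -> 31
+    s    = clip_sigmas * row_std / qmax   # per-row scale
+    # e.g. clip_sigmas = 12.85, row_std = 0.02 -> s ~ 0.00829; the grid clips
+    # at +/-12.85 sigma, so rounding, not clipping, dominates fake-quant error.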
+""" +import os +import sys +import re + +FILE = "/workspace/parameter-golf/code/train_gpt_stacked_v2_fixed.py" + + +def main(): + if not os.path.exists(FILE): + print(f"ERROR: file not found: {FILE}") + sys.exit(1) + src = open(FILE).read() + initial_size = len(src) + + # ============ PHASE A: Clean slate (remove any v1/v2 QAT) ============ + print("=== Phase A: clean any prior QAT state ===") + + # Remove v2 helpers (if present) + v2_helpers_re = re.compile( + r"\ndef _fake_quantize_per_row_gptq\(W, bits, clip_sigmas\):\n.+?QAT_EMBED_CLIP = float\(os\.environ\.get\(\"QAT_EMBED_CLIP\", \"20\.0\"\)\)\n", + re.DOTALL, + ) + if v2_helpers_re.search(src): + src = v2_helpers_re.sub("\n", src, count=1) + print(" Removed v2 QAT helpers") + + # Remove v1 helpers (if present) + v1_helpers_re = re.compile( + r"\ndef _fake_quantize_int6\(W\):\n.+?QAT_ENABLED=bool\(int\(os\.environ\.get\(\"QAT_ENABLED\",\"0\"\)\)\)\n", + re.DOTALL, + ) + if v1_helpers_re.search(src): + src = v1_helpers_re.sub("\n", src, count=1) + print(" Removed v1 QAT helpers") + + # Revert v2 CastedLinear forward (multi-line) to original + v2_castedlinear = ( + "def forward(self,x):\n" + "\t\tw_raw=self.weight\n" + "\t\tif QAT_ENABLED and self.training and getattr(self,\"_qat_active\",False):\n" + "\t\t\tw_raw=_fake_quantize_per_row_gptq(w_raw, QAT_MATRIX_BITS, QAT_MATRIX_CLIP)\n" + "\t\tw=w_raw.to(x.dtype)\n" + "\t\tbias=self.bias.to(x.dtype) if self.bias is not None else None\n" + "\t\treturn F.linear(x,w,bias)" + ) + original_castedlinear = "def forward(self,x):w=self.weight.to(x.dtype);bias=self.bias.to(x.dtype)if self.bias is not None else None;return F.linear(x,w,bias)" + if v2_castedlinear in src: + src = src.replace(v2_castedlinear, original_castedlinear, 1) + print(" Reverted v2 CastedLinear.forward") + + # Revert v1 CastedLinear forward (if present) + v1_castedlinear = ( + "def forward(self,x):\n" + "\t\tw_raw=self.weight\n" + "\t\tif QAT_ENABLED and self.training and getattr(self,\"_qat_active\",False):w_raw=_fake_quantize_int6(w_raw)\n" + "\t\tw=w_raw.to(x.dtype)\n" + "\t\tbias=self.bias.to(x.dtype)if self.bias is not None else None\n" + "\t\treturn F.linear(x,w,bias)" + ) + if v1_castedlinear in src: + src = src.replace(v1_castedlinear, original_castedlinear, 1) + print(" Reverted v1 CastedLinear.forward") + + # Revert v2 tok_emb wrapping (multi-line) back to original one-liner + v2_tokemb_block = ( + "_qat_w_emb=None\n" + "\t\tif QAT_ENABLED and self.training:\n" + "\t\t\t_qat_w_emb=_fake_quantize_per_row_gptq(self.tok_emb.weight, QAT_EMBED_BITS, QAT_EMBED_CLIP)\n" + "\t\t\tx=F.embedding(input_ids, _qat_w_emb)\n" + "\t\telse:\n" + "\t\t\tx=self.tok_emb(input_ids)" + ) + original_tokemb = "x=self.tok_emb(input_ids)" + if v2_tokemb_block in src: + src = src.replace(v2_tokemb_block, original_tokemb, 1) + print(" Reverted v2 tok_emb wrapping") + + # Revert v2 tied output wrapping + v2_tied_block = ( + "if self.tie_embeddings:\n" + "\t\t\tif _qat_w_emb is not None:\n" + "\t\t\t\tlogits_proj=F.linear(x, _qat_w_emb)\n" + "\t\t\telse:\n" + "\t\t\t\tlogits_proj=F.linear(x, self.tok_emb.weight)" + ) + original_tied = "if self.tie_embeddings:logits_proj=F.linear(x,self.tok_emb.weight)" + if v2_tied_block in src: + src = src.replace(v2_tied_block, original_tied, 1) + print(" Reverted v2 tied output") + + # Save cleaned state + open(FILE, "w").write(src) + print(f" Saved cleaned state ({len(src):,} bytes vs initial {initial_size:,})") + + # ============ PHASE B: Apply v3 (matrices only + warmup) ============ + print("\n=== Phase 
B: apply v3 ===") + + # Add v3 helpers (insert before RMSNorm) + v3_helpers = """ +# ============ QAT v3: matrices only + GPTQ-matched scale + warmup ============ +# Step counter is updated by the train loop via _qat_set_step(step) +_QAT_CURRENT_STEP = 0 + + +def _qat_set_step(step): + global _QAT_CURRENT_STEP + _QAT_CURRENT_STEP = step + + +def _qat_active_now(): + return QAT_ENABLED and _QAT_CURRENT_STEP >= QAT_WARMUP_STEPS + + +def _fake_quantize_matrix_gptq(W): + \"\"\"GPTQ-matched fake-quant for matrix weights at QAT_MATRIX_BITS.\"\"\" + qmax = 2**(QAT_MATRIX_BITS - 1) - 1 + row_std = W.float().std(dim=1, keepdim=True) + s = (QAT_MATRIX_CLIP * row_std / qmax).clamp_min(1e-10) + W_q = torch.round(W / s).clamp(-qmax, qmax) + W_dq = W_q * s + # Straight-through estimator + return W_dq.detach() + W - W.detach() + + +QAT_ENABLED = bool(int(os.environ.get("QAT_ENABLED", "0"))) +QAT_MATRIX_BITS = int(os.environ.get("QAT_MATRIX_BITS", "6")) +QAT_MATRIX_CLIP = float(os.environ.get("QAT_MATRIX_CLIP", "12.85")) +QAT_WARMUP_STEPS = int(os.environ.get("QAT_WARMUP_STEPS", "0")) # default 0 = always-on (no train-loop changes needed) + +""" + if "_fake_quantize_matrix_gptq" not in src: + if "class RMSNorm" not in src: + print(" ERROR: RMSNorm anchor not found") + sys.exit(1) + src = src.replace("class RMSNorm", v3_helpers + "class RMSNorm", 1) + print(" Added v3 helpers (with warmup support)") + + # Modify CastedLinear.forward + new_castedlinear = ( + "def forward(self,x):\n" + "\t\tw_raw=self.weight\n" + "\t\tif _qat_active_now() and self.training and getattr(self,'_qat_active',False):\n" + "\t\t\tw_raw=_fake_quantize_matrix_gptq(w_raw)\n" + "\t\tw=w_raw.to(x.dtype)\n" + "\t\tbias=self.bias.to(x.dtype) if self.bias is not None else None\n" + "\t\treturn F.linear(x,w,bias)" + ) + if original_castedlinear in src and "_fake_quantize_matrix_gptq" not in [ + line.strip() for line in src.split("\n") if "def forward(self,x):" in line[:60] + ]: + src = src.replace(original_castedlinear, new_castedlinear, 1) + print(" Wrapped CastedLinear.forward with v3") + + # Save before doing the optional marker step + open(FILE, "w").write(src) + + # Verify _qat_active markers exist (carried over from v1/v2) + if "lyr._qat_active=True" in src: + print(" Block matrix _qat_active markers already present (from v1/v2)") + else: + # Need to add marker block in GPT.__init__ + print(" WARNING: _qat_active markers missing — would need to add manually") + print(" Anchor candidates:") + if "self._init_weights()" in src: + anchor = "self.skip_gates=nn.Parameter(torch.zeros(self.num_skip_weights,h.model_dim,dtype=torch.float32))if h.skip_gates_enabled else None" + if anchor in src: + marker = """ +\t\tif QAT_ENABLED: +\t\t\tfor block in self.blocks: +\t\t\t\tfor lyr in (block.attn.c_q, block.attn.c_k, block.attn.c_v, block.attn.proj, block.mlp.fc, block.mlp.proj): +\t\t\t\t\tlyr._qat_active=True""" + src = src.replace(anchor, anchor + marker, 1) + print(" Added _qat_active markers") + + open(FILE, "w").write(src) + + # ============ PHASE C: Verify syntax ============ + import py_compile + py_compile.compile(FILE, doraise=True) + + # ============ PHASE D: Sanity checks ============ + print("\n=== Phase C: verify ===") + final = open(FILE).read() + checks = [ + ("def _fake_quantize_matrix_gptq(W):", True, "v3 helper"), + ("_QAT_CURRENT_STEP = 0", True, "v3 step counter"), + ("def _qat_set_step(step):", True, "v3 setter"), + ("def _qat_active_now():", True, "v3 active check"), + ("QAT_WARMUP_STEPS = int", True, "v3 warmup env"), + 
("_fake_quantize_matrix_gptq(w_raw)", True, "v3 in CastedLinear"), + ("lyr._qat_active=True", True, "matrix markers"), + # Things that should be ABSENT: + ("_fake_quantize_int6", False, "v1 helper (gone)"), + ("_fake_quantize_per_row_gptq", False, "v2 helper (gone)"), + ("F.embedding(input_ids, _qat_w_emb)", False, "v2 tok_emb wrap (gone)"), + ] + all_ok = True + for marker, expected_present, desc in checks: + actual = marker in final + ok = actual == expected_present + sym = "OK " if ok else "BAD" + print(f" [{sym}] {desc}: {'present' if actual else 'absent'}") + if not ok: + all_ok = False + + if not all_ok: + print("\nFAILED — review above") + sys.exit(1) + + print(f"\nDONE. File: {FILE} ({len(final):,} bytes)") + print("Train loop must call: _qat_set_step(step) at the top of each step") + + +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/patches/patch_swa.py b/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/patches/patch_swa.py new file mode 100644 index 0000000000..086547269b --- /dev/null +++ b/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/patches/patch_swa.py @@ -0,0 +1,43 @@ +"""Add Sliding-Window Attention via SWA_WINDOW_SIZE env var. + +When SWA_WINDOW_SIZE > 0, each query attends to last W tokens (inclusive of self). +When SWA_WINDOW_SIZE = 0 (default), full causal attention (baseline). + +Modifies a single line in the CausalSelfAttention forward call to pass +window_size to flash_attn_3_func. +""" +import re, os +F = 'train_gpt_stacked_v2_fixed.py' +src = open(F).read() + +# Add SWA_WINDOW_SIZE module-level constant after the Hyperparameters line +# Keep it simple: module global read from env +old_marker = "from flash_attn_interface import flash_attn_func as flash_attn_3_func" +new_marker = ("from flash_attn_interface import flash_attn_func as flash_attn_3_func\n" + "_SWA_WINDOW_SIZE = int(os.environ.get('SWA_WINDOW_SIZE', '0'))\n" + "def _swa_window_arg():\n" + "\treturn (-1, -1) if _SWA_WINDOW_SIZE <= 0 else (_SWA_WINDOW_SIZE - 1, 0)") + +if "_SWA_WINDOW_SIZE" not in src: + src = src.replace(old_marker, new_marker, 1) + print("Added SWA_WINDOW_SIZE global") +else: + print("SWA_WINDOW_SIZE already present") + +# Modify the flash_attn_3_func call to pass window_size +old_call = "y=flash_attn_3_func(q,k,v,causal=True)" +new_call = "y=flash_attn_3_func(q,k,v,causal=True,window_size=_swa_window_arg())" +if old_call in src: + src = src.replace(old_call, new_call, 1) + print("Modified flash_attn_3_func call to pass window_size") +else: + if new_call in src: + print("Already patched") + else: + print("ERR: target call pattern not found") + raise SystemExit(1) + +open(F, 'w').write(src) +import py_compile +py_compile.compile(F, doraise=True) +print(f"Patched: {len(src)} bytes, syntax OK") diff --git a/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/submission.json b/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/submission.json new file mode 100644 index 0000000000..3e777b0280 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-18_EvalTimeAblations_SP8192/submission.json @@ -0,0 +1,75 @@ +{ + "author": "himanshudongre", + "github_id": "himanshudongre", + "name": "Ablation Report: Eval-time lever analysis on SP8192 absolute-RoPE stack", + "date": "2026-04-18", + "track": "non_record_16mb", + "related_record_pr": "https://github.com/openai/parameter-golf/pull/1716", + "summary": "Structured ablations behind the 2026-04-18 
SP8192 + BigramHash d=32 + Path A v3 record (val_bpb 1.07882, 3-seed mean). Documents what was tested, what worked (d=32 bigram, Path A v3 aggressive passthrough quantization), what was killed (TTT_EPOCHS=4 saturated, EVAL_SEQ_LEN=4096 OOD on sliding, QAT v3 x score-first-TTT catastrophic interaction, Adaptive Hadamard null on Muon weights), and the architectural reasons behind each null result. Primary contribution: documentation and mechanism analysis for the design space around eval-time compute on absolute-position RoPE models.", + "hardware": "8xH100 80GB SXM", + "pytorch_version": "2.9.1+cu128", + "methodology": { + "protocol": "single-seed probe first, 3-seed validation only on 2-sigma winners", + "seed_for_probes": 42, + "baseline_reference": "record PR #1716 seed 42 (TTT val_bpb 1.07886574, total artifact 15,991,203 B)" + }, + "findings": { + "confirmed_positive": { + "bigram_dim_32": { + "vs_d48": "pre-quant -0.0002 bpb, consistent across retrains", + "artifact_side_effect": "bigram.proj shrinks from 24576 to 16384 params, enabling Path A v3 int8 threshold" + }, + "path_a_v3_passthrough_quant": { + "mechanism": "per-tensor int8 on 5 control-tensor families + per-row int8 on 3 small 2D matrices + LZMA code wrapper", + "artifact_savings": "~40 KB at 0 bpb cost (5 d.p.)", + "bpb_roundtrip_cost": "0.00000 measurable (identical to baseline at 5 decimal places)" + } + }, + "confirmed_null": { + "ttt_epochs_4": { + "delta_bpb": -8.96e-5, + "sigma_multiple": 0.06, + "verdict": "saturated" + }, + "eval_seq_len_4096_stride_128": { + "pre_quant_delta": -0.00509, + "quant_delta": -0.00334, + "sliding_delta": 0.00555, + "ttt_delta": 0.00033, + "verdict": "OOD query-phase tax at sliding-tail positions dominates context gain", + "architectural_implication": "requires training-time change (ALiBi/SWA/longer train_seq_len) to unlock" + }, + "qat_v3_matrices_only": { + "pre_quant_drift": 0.015, + "ttt_result": 1.48169, + "verdict": "catastrophic QAT x score-first-TTT interaction", + "mechanism": "SGD during TTT pushes weights off-lattice; QAT-trained model highly sensitive to this" + }, + "adaptive_hadamard_gptq_on_muon": { + "verdict": "no MSE reduction; Muon produces sub-Gaussian per-row weights that are already smooth", + "literature_note": "SpinQuant/QuaRot benefits do not transfer to Muon-trained small models" + } + } + }, + "tokenization_advisory": { + "issue": "SP8192 shards found on some RunPod volumes are tokenized with a different SP8192 BPE model than the one in the canonical tokenizer file, causing a ~5.6% val_bpb-scale shift", + "fingerprint": "random-init step-0 val_bpb should be ~3.487 on canonical data; numbers far from this indicate a tokenization mismatch", + "resolution": "regenerate tokenizer + shards via download_hf_docs_and_tokenize.py --variant sp8192" + }, + "compliance": { + "issue_1017_track_a_permitted": "eval-time sliding-window patterns, KV-cache strategies, inference optimizations without updating model state", + "issue_1017_track_b_permitted": "score-first TTT (score chunk, then train on it)", + "no_slot": true, + "no_etlb": true, + "no_ngram_cache": true, + "no_pre_quant_ttt": true, + "no_condition_violations_in_any_probe": true + }, + "future_directions_ranked": [ + "SWA training (flash_attn_3 window_size) to unlock eval-time KV chain", + "TRAIN_SEQ_LEN=4096 direct training", + "Position-dropout at training time", + "Legitimate FLA / GatedDeltaNet with rigorous byte-accounting", + "Per-document LoRA TTT over score-first legal TTT" + ] +}