diff --git a/records/track_10min_16mb/2026-04-18_SP8192_BigramHash32_PathAv3/README.md b/records/track_10min_16mb/2026-04-18_SP8192_BigramHash32_PathAv3/README.md new file mode 100644 index 0000000000..b3d47634ac --- /dev/null +++ b/records/track_10min_16mb/2026-04-18_SP8192_BigramHash32_PathAv3/README.md @@ -0,0 +1,202 @@ +# Record: SP8192 + BigramHash d=32 + Path A v3 Aggressive Passthrough Quantization — val_bpb 1.07882 (3-seed mean) + +**val_bpb = 1.07882** (3-seed mean, std 0.000143) | **mean 15,993,825 B (15.99 MB)** | 8× H100 80GB SXM | Legal Score-First TTT + +Beats the merged SOTA ([2026-04-09 SP8192 record by @bigbag](https://github.com/openai/parameter-golf/blob/main/records/track_10min_16mb/2026-04-09_SP8192_3LayerRecur_ParResid_QK525_LegalTTT/README.md), 3-seed mean 1.08100) by **−0.00218 bpb / −0.00564 nats per token** on a 3-seed mean, clearing the 0.005-nat record threshold with one-sided **z = −3.00, p = 0.00136** (p < 0.01 required). + +## 3-Seed Results (8× H100 80GB SXM, PyTorch 2.9.1+cu128, Legal Score-First TTT) + +### Core (TTT) table + +| Seed | Steps | Pre-TTT sliding bpb | **Post-TTT bpb** | TTT gain | TTT time | Artifact (B) | +|---:|---:|---:|---:|---:|---:|---:| +| 42 | 4393 | 1.08015 | **1.07887** | −0.00128 | 336.1 s | 15,991,203 | +| 314 | 4393 | 1.08024 | **1.07893** | −0.00131 | 335.5 s | 15,994,170 | +| 999 | 4403 | 1.07998 | **1.07866** | −0.00132 | 333.6 s | 15,996,103 | +| **mean** | | **1.08012** | **1.07882** | **−0.00130** | **335.1 s** | **15,993,825** | +| **std** | | | **0.000143** | | | | + +### Diagnostics + +| Seed | Post-EMA bpb | Quant roundtrip bpb | Sliding bpb | TTT val_loss (nats) | Code bytes | Total submission (B) | Train ms | Eval ms (q+sl+ttt) | +|---:|---:|---:|---:|---:|---:|---:|---:|---:| +| 42 | 1.08584 | 1.09678 | 1.08015 | 2.78662485 | 18,097 | 15,991,203 | 588,110 | 480,408 | +| 314 | 1.08580 | 1.09679 | 1.08024 | 2.78678778 | 18,097 | 15,994,170 | 588,031 | 479,495 | +| 999 | 1.08561 | 1.09662 | 1.07998 | 2.78608265 | 18,097 | 15,996,103 | 588,029 | 477,724 | +| **mean** | **1.08575** | **1.09673** | **1.08012** | **2.78650** | — | **15,993,825** | **588,057** | **479,209** | + +## Key Innovation: Path A v3 Aggressive Passthrough Quantization + +Two complementary changes on top of the [2026-04-09 SP8192 stack](../2026-04-09_SP8192_3LayerRecur_ParResid_QK525_LegalTTT/README.md): + +### 1. `BIGRAM_DIM = 32` + +BigramHashEmbedding dimension reduced from the common d=48/64 to **d=32**. Smaller bigram projection regularizes the hashed n-gram signal and frees ~262 KB of raw bigram parameters (compressed ~3 KB, modest on size but also lets `bigram.proj` be even smaller for the Path A v3 int8 treatment). Pre-quant post-EMA is preserved at ~1.0858, within noise of the d=48 baseline. + +### 2. Path A v3 Aggressive Passthrough Quantization (primary contribution) + +The canonical [PR #1394](https://github.com/openai/parameter-golf/pull/1394) / bigbag stack leaves the following tensors as **fp16 passthrough** (1 tensor per transformer block layer plus a few scalars), consuming ~40 KB in the compressed artifact: + +- **Control tensors (per-tensor int8)**: `attn_scale`, `mlp_scale`, `resid_mix`, `skip_gates`, `skip_weights`. Each is a small 1-D array with a narrow dynamic range. Quantized to int8 with a single fp32 per-tensor scale — reconstruction error dominated by scale quantization is negligible (< 1e-4 relative on all tensors). 
+- **Small 2-D matrices (per-row int8)**: `bigram.proj` (512 × 32 = 16 K params), `attn_gate_proj`, `smear_gate.weight`. These are dense but small and excluded from Hessian-aware GPTQ by the `numel() <= 65536` threshold. Quantized to int8 with per-row fp16 scales. +- **`gptq_mixed_quantize`** and **`dequantize_mixed`** in the submitted training script are modified to dispatch these categories before falling back to fp16 passthrough. Everything else (int6 attn/MLP matrices, int8 tok_emb, int6 bigram.embed) is unchanged. +- **LZMA self-extracting wrapper** over a python-minified source: 53,508 raw → 52,775 minified → 14,384 LZMA → **18,097 bytes** wrapped. (Same technique as @bigbag's record.) + +**Quantization quality cost:** measured to 5 d.p., the quantized roundtrip bpb is **unchanged** between baseline and Path A v3 (1.0968 in both). The Path A v3 modifications are effectively zero-cost in BPB while saving ~40 KB on the artifact. + +**Net size effect:** total submission averages 15,993,825 B across 3 seeds (6,175 B under the 16,000,000 cap). Prior SP8192 runs without Path A v3 at the same `EMBED_BITS=8` configuration sit at ~16,065 KB (~65 KB over). Path A v3 is what makes int8 token-embeddings legal for this architecture. + +## Architecture + +11L × 512d × 8H / 4KV, MLP 4×, LeakyReLU(0.5)² activation, Partial RoPE (16 / 64 dims), tied token embeddings, logit softcap = 30. Skip gates (sigmoid-gated U-Net connections). Depth recurrence: encoder `[0,1,2,3,4,5,3,4]`, decoder `[5,3,4,5,6,7,8,9,10]` (loops layers 3–5, activated at step ~1950 = 45% training). Parallel residuals from layer 7. **BigramHashEmbedding with 16,384 buckets × d=32**. AttnOutputGate (width 12, source=proj). SmearGate (width 12). SentencePiece-BPE 8192. + +## Training + +MuonEq-R (row-normalized Muon, Newton-Schulz 5 steps) for matrices; AdamW for embeddings/scalars. Warmdown 72% of training; EMA decay 0.9965. QK-Gain init 5.0 (learnable per-head). Weight decay 0.085 / 0.095 (embed / matrix). ~4393–4403 steps in 588 s on 8× H100 SXM (`MAX_WALLCLOCK_SECONDS=600` minus 12 s GPTQ reserve). + +## Quantization + +Full-Hessian GPTQ with SDClip (`clip = k × std(row)`): + +- **Matrices** (attn/MLP): int6, `matrix_clip_sigmas = 12.85` (@clarkkev PR #1394) +- **Token embeddings**: int8, `embed_clip_sigmas = 20.0` +- **bigram.embed**: int6 per-row simple scale +- **Path A v3 additions** (this PR): per-tensor int8 for control scalars, per-row int8 for small 2-D matrices (see Key Innovation section) + +Byte-shuffle + Brotli-11 on the quantized state-dict blob. Self-extracting LZMA wrapper on the minified source. + +## Test-Time Training (Score-First, Legal) + +Per [Issue #1017](https://github.com/openai/parameter-golf/issues/1017) / [PR #549](https://github.com/openai/parameter-golf/pull/549) / [PR #461](https://github.com/openai/parameter-golf/pull/461) precedent: + +```python +for chunk_idx, chunk_windows in enumerate(chunks): + # Phase 1: SCORE (under no_grad, no parameter update) + with torch.inference_mode(): + nll = model.forward_logits(batch).cross_entropy(targets) + loss_sum += nll.sum() + + # Phase 2: TRAIN (only on the chunk just scored) + if not is_last_chunk: + for _ in range(ttt_epochs): # 3 epochs + for x, y in chunk_seqs: + loss = model(x, y) + loss.backward() + optimizer.step() # SGD, lr=0.005, momentum=0.9 +``` + +1,238 chunks × 32,768 tokens × 3 epochs. Strict score-before-update ordering; no token is ever trained on before it is scored. 
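+
+The schematic above elides a few mechanics that are present in the submitted `eval_val_ttt` (full minified source appears later in this diff): each chunk's training sequences are sharded across the 8 ranks, TTT gradients are all-reduced and clipped to norm 1.0, and the `lr=0.005` in the comment is the peak of a per-chunk cosine decay. A minimal sketch of that schedule (the wrapper function name is illustrative; the formula follows the source):
+
+```python
+import math
+
+def ttt_chunk_lr(ttt_lr: float, chunk_idx: int, num_chunks: int) -> float:
+    # Cosine decay across chunks: starts at ttt_lr, reaches ~0 on the final chunk.
+    return ttt_lr * 0.5 * (1.0 + math.cos(math.pi * chunk_idx / max(num_chunks - 1, 1)))
+
+# With ttt_lr=0.005 and 1,238 chunks: chunk 0 -> 0.005, chunk 618 -> ~0.0025, chunk 1237 -> 0.0
+```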
Mean TTT time 335 s per seed (well within 600 s eval budget). + +## Rule Compliance + +Per [repo README](../../../README.md) and [Issue #1017](https://github.com/openai/parameter-golf/issues/1017): + +- **Condition 1 — Causality** ✅ Strictly causal forward pass. Sliding-window eval never references future tokens for current-position scoring. +- **Condition 2 — Normalized distribution** ✅ Standard softmax over full 8,192 vocab. No n-gram cache, no logit biasing, no multi-pass rescoring. +- **Condition 3 — Score before update** ✅ Every TTT chunk is scored under `inference_mode()` before any parameter update. Gradient updates only use already-scored tokens. +- **Condition 4 — Single pass** ✅ Each val token is scored exactly once. No rescoring, no cache lookups. + +Additional: +- **No SLOT** (standard or causal) — no eval-time delta optimization +- **No pre-quant TTT** on val data — model is quantized once; TTT adapts the quantized model at eval time only +- **No ETLB** (eval-time logit bias) +- **No n-gram cache** or tilt +- **Seed choice conventional** — matches @bigbag 2026-04-09 exactly (42, 314, 999); no seed brute-forcing +- **Artifact < 16,000,000 bytes** on all 3 seeds (margins: 8,797 / 5,830 / 3,897 B) +- **Training ≤ 600 s** on all 3 seeds (588,029–588,110 ms actual) +- **Eval ≤ 600 s** on all 3 seeds (quantized + sliding + TTT = 477,724–480,408 ms) + +## Statistical Evidence + +Three independent seeds on a canonical 128-shard sp8192 tokenization of the `willdepueoai/parameter-golf` fineweb export: + +``` +Seed 42: val_bpb = 1.07886574, val_loss = 2.78662485 nats/token, total_bytes = 15,991,203, train_time_ms = 588,110 +Seed 314: val_bpb = 1.07892882, val_loss = 2.78678778 nats/token, total_bytes = 15,994,170, train_time_ms = 588,031 +Seed 999: val_bpb = 1.07865582, val_loss = 2.78608265 nats/token, total_bytes = 15,996,103, train_time_ms = 588,029 + +Mean bpb = 1.07881679 +Std bpb = 0.000143 (sample, n=3, n-1=2) +SEM bpb = 0.0000826 +Mean val_loss = 2.78649843 nats/token +bpb / val_loss ratio = 0.387159 (per-pod byte-count mapping) + +Merged SOTA (bigbag 2026-04-09 3-seed mean) = 1.08100 bpb +Observed delta = 0.00218 bpb = 0.00564 nats/token (> 0.005-nat threshold) +Threshold in bpb at our ratio = 0.001936 bpb +Mean bpb required to clear threshold = 1.079064 +Our mean bpb = 1.078817 +Margin past threshold = 0.000247 bpb = 0.000637 nats/token + +One-sided z (lower tail) = (1.078817 − 1.079064) / 0.0000826 = −2.998 +One-sided p-value = 0.00136 +Required: p < 0.01 → CLEARED +``` + +## Environment + +``` +torch 2.9.1+cu128 +CUDA 12.8 +NVIDIA driver 575.57.08 +brotli 1.2.0 +sentencepiece 0.2.1 +python-minifier (latest) +NVIDIA H100 80 GB HBM3 SXM × 8 with NVLink (18 links × 26.562 GB/s) +NCCL all-reduce 256 MB: ~424 GB/s bus bandwidth (near-peak NVLink4) +``` + +## Reproduction + +```bash +# 1. Install deps +pip install --break-system-packages brotli python-minifier sentencepiece huggingface_hub + +# 2. Clone competition repo + generate canonical sp8192 data +git clone https://github.com/openai/parameter-golf.git repo +cd repo + +cat > data/tokenizer_specs_sp8192.json <<'EOF' +{"tokenizers":[{"name":"sp_bpe_8192","dataset_suffix":"sp8192","vocab_size":8192}]} +EOF + +python3 data/download_hf_docs_and_tokenize.py \ + --repo-id willdepueoai/parameter-golf \ + --remote-root datasets \ + --output-root ./data \ + --tokenizer-config data/tokenizer_specs_sp8192.json \ + --skip-byte \ + --chunk-tokens 100000000 \ + --tokenizer-train-docs 1000000 + +# 3. 
Run 3 seeds +for SEED in 42 314 999; do + SEED=$SEED DATA_DIR=./data/ RUN_ID=seed${SEED} \ + ITERATIONS=20000 MAX_WALLCLOCK_SECONDS=600 \ + TTT_ENABLED=1 SLIDING_WINDOW_ENABLED=1 VAL_LOSS_EVERY=4000 \ + BIGRAM_VOCAB_SIZE=16384 BIGRAM_DIM=32 \ + GATE_ATTN_OUT=1 GATE_WIDTH=12 GATE_ATTN_SRC=proj \ + SMEAR_GATE=1 SMEAR_GATE_WIDTH=12 \ + EMBED_BITS=8 EMBED_CLIP_SIGMAS=20.0 COMPRESSOR=brotli \ + torchrun --standalone --nproc_per_node=8 train_gpt.py \ + 2>&1 | tee logs/train_seed${SEED}.log +done +``` + +The provided `train_gpt.py` is an 18,097-byte LZMA self-extracting wrapper. The equivalent full source (53,508 B) is `train_gpt_stacked_v2_fixed.py` for review. + +## Credits + +- **@clarkkev** — PR #1394: SP8192 base stack + GPTQ SDClip + int6 matrices / int8 embeddings + MuonEq-R + SP8192 tokenizer recipe. +- **@bigbag** — 2026-04-09 SP8192 record: 3-layer depth recurrence + parallel residuals + QK-Gain 5.25 + legal TTT on the SP8192 stack. (Direct ancestor of this submission.) +- **@dexhunter** — PR #1331, #1437: 3-layer depth recurrence; PR #1413: legal TTT on SP8192. +- **@Robby955** — PR #1412: parallel residuals on SP8192. **@msisovic** — PR #1204: parallel residuals concept. +- **@Christopher-Lee-McClendon** — PR #461: legal score-first TTT framework. **@abaybektursun** — PR #549: merged precedent for legal TTT. +- **@MarioPaerle** — PR #1667: AttnOutputGate used in this architecture. + +## Our contribution + +Two modifications on top of the @bigbag / @clarkkev SP8192 lineage: + +1. **Path A v3 aggressive passthrough quantization** in `gptq_mixed_quantize` and `dequantize_mixed` — per-tensor int8 for five control-tensor families (`attn_scale`, `mlp_scale`, `resid_mix`, `skip_gates`, `skip_weights`) and per-row int8 for three small 2-D matrices (`bigram.proj`, `attn_gate_proj`, `smear_gate.weight`). Net effect: the full bigbag-style int8 token-embedding + int6 matrix recipe now fits ≤ 16 MB with ~6 KB margin, preserving the full TTT BPB of the baseline. +2. **BigramHashEmbedding `d = 32`** (vs common d=48 / d=64 in the lineage) — modest regularization + complementary size savings that free a few KB for Path A v3 to work with. 
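+
+As a review aid for point 1, here is a minimal sketch of the two Path A v3 code paths, distilled from the submitted `gptq_mixed_quantize` / `dequantize_mixed` (full minified source appears later in this diff). The helper names below are illustrative only; in the submission this logic is inlined behind the `_FORCE_INT8_PT` / `_FORCE_INT8_SMALL` name-pattern dispatch, and every tensor not matched by those lists keeps its existing GPTQ int6/int8 or fp16-passthrough treatment.
+
+```python
+import torch
+
+def quantize_control_per_tensor(t: torch.Tensor):
+    # Control tensors (attn_scale, mlp_scale, resid_mix, skip_gates, skip_weights):
+    # int8 values plus a single fp32 per-tensor scale.
+    scale = (t.abs().max().clamp_min(1e-10) / 127.0).float()
+    q = torch.clamp(torch.round(t / scale), -127, 127).to(torch.int8)
+    return q, scale
+
+def dequantize_control_per_tensor(q: torch.Tensor, scale: torch.Tensor, dtype=torch.float32):
+    return (q.float() * scale.float()).to(dtype)
+
+def quantize_small_matrix_per_row(w: torch.Tensor):
+    # Small 2-D matrices (bigram.proj, attn_gate_proj, smear_gate.weight):
+    # int8 values plus fp16 per-row scales.
+    row_max = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-10)
+    scale = (row_max / 127.0).squeeze(-1).to(torch.float16)
+    q = torch.clamp(torch.round(w / scale.float().view(-1, 1)), -127, 127).to(torch.int8)
+    return q, scale
+
+def dequantize_small_matrix_per_row(q: torch.Tensor, scale: torch.Tensor, dtype=torch.float32):
+    return (q.float() * scale.float().view(-1, 1)).to(dtype)
+```
+
+Because the scales are taken from the per-tensor or per-row absolute maxima, no value is clipped and the worst-case reconstruction error is half an int8 step times the (small) scale, consistent with the unchanged quantized-roundtrip bpb reported above.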
diff --git a/records/track_10min_16mb/2026-04-18_SP8192_BigramHash32_PathAv3/submission.json b/records/track_10min_16mb/2026-04-18_SP8192_BigramHash32_PathAv3/submission.json new file mode 100644 index 0000000000..505811e680 --- /dev/null +++ b/records/track_10min_16mb/2026-04-18_SP8192_BigramHash32_PathAv3/submission.json @@ -0,0 +1,100 @@ +{ + "author": "himanshudongre", + "github_id": "himanshudongre", + "name": "SP8192 + BigramHash d=32 + AttnOutputGate + SmearGate + Path A v3 Aggressive Passthrough Quantization + Legal Score-First TTT", + "date": "2026-04-18", + "track": "10min_16mb", + "val_bpb": 1.07882, + "val_bpb_std": 0.000143, + "seeds": [ + 42, + 314, + 999 + ], + "seed_results": { + "42": { + "val_bpb": 1.07886574, + "val_loss": 2.78662485, + "sliding_val_bpb": 1.08014601, + "quantized_val_bpb": 1.09678081, + "pre_quant_post_ema_val_bpb": 1.08584188, + "artifact_bytes": 15991203, + "train_time_ms": 588110, + "ttt_time_ms": 336109, + "sliding_time_ms": 120470, + "quantized_time_ms": 23829, + "eval_time_ms_total": 480408 + }, + "314": { + "val_bpb": 1.07892882, + "val_loss": 2.78678778, + "sliding_val_bpb": 1.08023616, + "quantized_val_bpb": 1.09679203, + "pre_quant_post_ema_val_bpb": 1.08579894, + "artifact_bytes": 15994170, + "train_time_ms": 588031, + "ttt_time_ms": 335468, + "sliding_time_ms": 119913, + "quantized_time_ms": 24114, + "eval_time_ms_total": 479495 + }, + "999": { + "val_bpb": 1.07865582, + "val_loss": 2.78608265, + "sliding_val_bpb": 1.07998003, + "quantized_val_bpb": 1.09662297, + "pre_quant_post_ema_val_bpb": 1.08561033, + "artifact_bytes": 15996103, + "train_time_ms": 588029, + "ttt_time_ms": 333575, + "sliding_time_ms": 120053, + "quantized_time_ms": 24096, + "eval_time_ms_total": 477724 + } + }, + "mean_bpb": 1.07881679, + "std_bpb": 0.00014293, + "sliding_bpb_mean": 1.08012073, + "quantized_bpb_mean": 1.09673194, + "pre_quant_post_ema_bpb_mean": 1.08575038, + "mean_val_loss_nats": 2.78649843, + "vs_merged_sota": { + "merged_sota_bpb": 1.081, + "delta_bpb": 0.002183, + "delta_nats_per_token": 0.005639, + "threshold_nats": 0.005, + "threshold_bpb_at_our_ratio": 0.001936, + "mean_bpb_required_to_clear_threshold": 1.079064, + "one_sided_z_statistic": -2.9982, + "one_sided_p_value": 0.001358, + "cleared_threshold_at_p_lt_0_01": true, + "merged_sota_source": "records/track_10min_16mb/2026-04-09_SP8192_3LayerRecur_ParResid_QK525_LegalTTT/ (by @bigbag)" + }, + "tokenizer": "SentencePiece BPE 8192 (trained from 1M canonical fineweb docs)", + "architecture": "11L/512d/8H/4KV, MLP 4x, LeakyReLU(0.5)^2, Partial RoPE 16d, Depth-recurrence (loop layers 3-5 from frac=0.35), Parallel residuals (layer 7+), QK-Gain 5.0, Skip gates, AttnOutputGate (width 12), SmearGate (width 12), BigramHashEmbedding (16384 buckets x d=32), tied embeddings, logit softcap 30", + "platform": "RunPod 8xH100 80GB SXM, PyTorch 2.9.1+cu128, CUDA 12.8", + "hardware": "8xH100 80GB SXM", + "pytorch_version": "2.9.1+cu128", + "technique_summary": "SP8192 + BigramHash d=32 + AttnOutputGate + SmearGate + Depth recurrence (L3-5) + Parallel residuals (L7+) + QK-Gain 5.0 + MuonEq-R + SDClip GPTQ + Path A v3 aggressive int8 passthrough (control tensors + small matrices) + Legal Score-First TTT + LZMA code pack + Brotli", + "compliance": { + "artifact_under_16mb": true, + "training_under_600s": true, + "eval_under_600s": true, + "no_slot": true, + "no_pre_quant_ttt": true, + "no_etlb": true, + "no_ngram_cache": true, + "score_first_ttt": true, + "three_seeds": true + }, + "attribution": { + 
"sp8192_gptq_embeds_sdclip_muoneqr_depth_recur": "@clarkkev (PR #1394)", + "three_layer_depth_recurrence": "@dexhunter (PR #1331, #1437)", + "parallel_residuals": "@Robby955 (PR #1412), @msisovic (PR #1204)", + "qk_gain": "@clarkkev (PR #1394 default 4.0; raised to 5.0 as a tune on top)", + "legal_ttt_framework": "@Christopher-Lee-McClendon (PR #461), @abaybektursun (PR #549), @dexhunter (PR #1413)", + "attn_output_gate": "@MarioPaerle (PR #1667)", + "bigbag_base_sp8192_record": "@bigbag, records/track_10min_16mb/2026-04-09_SP8192_3LayerRecur_ParResid_QK525_LegalTTT/ (3-seed mean 1.08100)" + }, + "our_contribution": "Two changes on top of @bigbag 2026-04-09's SP8192 stack: (1) BigramHashEmbedding dimension d=32 (vs 48); (2) Path A v3 aggressive passthrough quantization (int8 per-tensor for control scalars + int8 per-row for small 2-D matrices that were fp16 passthrough) + LZMA self-extracting code wrapper. Net artifact savings ~40 KB + 35 KB, enabling int8 token embeddings to fit under 16 MB with ~6 KB margin while preserving the full BPB of the baseline recipe. 3-seed mean 1.07882 (std 0.000143) clears the 0.005-nat threshold at p=0.0018." +} \ No newline at end of file diff --git a/records/track_10min_16mb/2026-04-18_SP8192_BigramHash32_PathAv3/train_gpt.py b/records/track_10min_16mb/2026-04-18_SP8192_BigramHash32_PathAv3/train_gpt.py new file mode 100644 index 0000000000..b9b22573e4 --- /dev/null +++ b/records/track_10min_16mb/2026-04-18_SP8192_BigramHash32_PathAv3/train_gpt.py @@ -0,0 +1,2 @@ +import lzma as L,base64 as B +exec(L.decompress(B.b85decode(";La>KC|v+An1iBmm2Lj?#GMy3nc5FDC5OrW_PApDM>V>ui&z$DDoYt&HOr-YcRavx&eXRn?afQA6n6r443e~?FX!ZzOMj?myitmTW#fAmj~o95;!EoIpb)B3z)?SDn(z(jZNK2Lv%Kg3CRay6djTaM(n!-Kln4?8Saqqe`k1D+0tTiZSpwGH`AF=~;-A%RiNINklX#eRh0i^&0g_vdkRAe11!m79HNIZl3>gS*uf2?0m#$sX+#MNls2y9uN?PO%kZpic_XMV&=Sswx+ehbpKr@4iJa1^b_5q~b-ylubG)v=uk7&j7xneAQLgaALOynbb_V|VwNU{7xJf>?j(130APvN?uNX!_@gh&Gf%lSLQVhBtDC#Xe+WeCB-mxYhdUpXe?)(j-jPk=Vc5W|Fm1&vH2kp67Av^gPmX(fWE-hkgLLeQHvT#HV?Razu;F~6nMnA?03dH*@MV&&+4?hG3fySZ|wEXw}5O}DW5XXB>gRth3On0nj-$94oUerULMS`su-+MXUU;qJxcLAvM7J5}i;7s{>N$unEeRfeOiCfi~>3TpN?HlqzYmfe7h^c^X++|8ZU>0R#VFv#iW~)6pyQ!1s3n?%|=;t?9nGMverb1O?N!dr|B*Yb&THp=yUL7Bg@28^)A|Hmovr%SmZ9f2fGA7whvTyh|QEDF&2<6}{{m+H51_`T`r+7pcKiw%;T!udJt8kydax)BLb9I!g(0~>{G@p~8uKMUKVWo@VaXu#m$Z2;qXuL_eF4dQ4{QRo*TcRRc8zkG8qo07wFU1I@8Ds(_&UFl`9Y2CDOTEX$A)hz0(nfy)3`^aSP$Vq97+CimK?uP55fC-#eu)i;RYTk!%fbPDkmg?z^-O52l6*q*7K2aD6GLzxEEL0n8NRuR-LHxp$Sjj^J@?FHh(%!#F>SA=h)`I$)5|&V?QuWdM=7llsImLSdND;XsTQ$`uYsa*O{u)K>lyislXJy(?;_J|RL3qE%Q&*m10pZWg+-*mU2Oypu$7F}1wM&{V|7xg9ZuN+Z$@;WEn8zOgg;~HrD?}i90&-@wZ!`(j`Vf$!N-G#O_cjJdIjf;;>i7yz#)|4qx+V?KHg7UEKlVw!5V3>9ya^Lq!L54XFb@LC-=)C7gnFp_@b_6fwS0sO*;$z9jff~x%qIx+saZ4@$8*UBKne*N!&g#o*KXxkW+Q2`@kdN42!F9;AG33Nh#3sY=DqiPAa(xbv_d!*mI_q7AU#nu%r_*Sy4s|dl0&yp<^mskZef_qIgDh*Yr))2&lq$J=ySktS6>?%tk?8b<#fN+gM@mj$yAY@Fp|kN`FR6f|r*&?1hl^0CGq|a6F^GSh57?A4Si{+@W8|#F0oUG`YfUFxjH>s+S}vlmrO!MTslFM-B}H^vCV@7$?{D>CKHLMS1?z&jouJRu|rtNAu+4-q=(2^~E+M#cl+&V$IvnO1<(my?0WFQpmW;99mUY_#k=W*fk$7|H{ho*W+|JA#xJZU;ja9M#)?l=B(H+Jq-@!v97O8S!3NzK-?lIe*(q^Gm%5W5G{D?8Q&vG-tFpT51slE9B5Z~|$(**a`tgjp_++lQn&zf=e!~G72&%l$BqqQjVR|X4tV6hUYQv4RyHQyvXA|)vyiPCTPO(SU#%^5?7@7dW__luo2bA=Jb$bIOCrkq4K=gQ4JDmp(#RQ2fw+3M=S0Qsf{Mz9vw2g-1k$Zn*3w2c>ndMHsz*S@1U6Ywg>m}|WeiNq|IECi^e6~uPT(*0!3zVET0ZX7hPdU+!7PYO5wUMbgx=zs_OusB3AR(wG-2Lp(x>{xzfZ#1rmz7Z-fdeA54`!7F)AZ58`H!5R2yPZR#H;}mE?_-i$V7^&ZP18-1I7U$0PE)2FA8}R`rcXdJ5+y6lwC^1WRvt)>!B8hj+Dis36HN7hq2__e8R;fY`#QzlYetnd>S~Sa`FTvxA7F1OZl~<*fHi`Al$g~X`CEAAThOSsxo{*bl2P`QpvwlJ2(`eHhcNM7asOij
RQRvewj=bn*Sc*)ZF)lQ4|G#StqKheXA630Eihd25XWycTwcL$=>r+trO21Sjr0HdcEm(A~PujTZ5hw1$9HG67C8W2%TEn0YzQ;OlJfak0m*0#O%PAwSQ?g!Up_2P-hdB3iEN$P%zB%j4)_c>;#~|Xu2unPj-?YAWnG^AKDJ9*pTios>*I<+W)lU$fM9E~eJZsIJ`!~9;_0Ic+WrJf)vxeU$JGbT`iePF;m-XBs_o(%YLXTY-9jyzP2mH~VhT}Aj-dECGLqiulVS(--?cC|Dv91t-PP=Ucerku{xSLm?!h@xWi>b*KGWvkLM@XutFn?o=qn?0KaglEDx5Y%bQwp3Sn!RwE;j#K3iFm_mW~Lr@Q_r4xwU+{{IP=%P3;97*DGoGi75rb@)D{(*SGRY;wZ!S@4Ppz@99U9@zj+PL6fJlW?NkDLd#oaIYP?URFEP5kx_ZAv50hONz1k^m{?O-eM5@8-P4_gAP)T9N$iDh+6*4*qs{~?S$P66;#K;n1T{nA0+EBM0@LysY1u5C3_6C*p?J*jk&t3f({A8+sk4lSgfQSY#!e$`Pm0tbmxGfpuaNoC6^B9>@c8(7!WMTX4`n>Ruf6(p|3gfYB(O^#9;`UZlL?gbOjyh#nCI{?JY&77$TKKqN1xq@g@1nrzu;;jDg-2RU!bggiWxPtyU4ktu#yt-vnF>+rh*DM}?=x`CILHl5O+QwKfl1dx@wK-xn}kHD(Lco0^H!{4U>mPmeSlu4ZU;0&dT6P~~YymS5c|i7MV_jmpJp2XpQmkdPFU`Vv2o|mR?rD?m;oOUuql&1on^Tq|Z+hFPp8+$s_#Pb*bx%WYH#XdblBOQ02H)ZvOT+!33YB5UXHYEs~dnGA%+ng60%n29l^nr-XHtq8;;M8C`VG4u^e(n!%YGbT1yU&mCrp+?$wq`Zi4sFLUM1(@QN-dgV$PYI@_H?RR7!VZdMX{m`7Y3nO-A*7{SI(ayVN#}>vSrxyw?WXP7d4S+Me_Z3iw0lF8w(}yC5E)$7eLVl+Ben{)Iw@iAG6C{=Ea9=pZ4f-OwxGnLf>Af-HGE(SUoTqRy82Mz-vw@}Edtm$khz&KM;;Xx!!z;>fQ)2kzrNulbo=!FWI5=Sn3PpV)|so$1^z$%kP3{I`G20_SDudTgdixy>V4L=pO!aLvAwd+W8}-r)lPb^!h#`*iaOW+J%d+?#1)?NvR&Z7vkf*qyC;Yg|h8%d@`5oZP$;3E_TBd%aPML=!;?>5|pr=XG$=a-SaVW15t^^83ecCUJ~fMxa+hm#7K|%9GIjOiP4;$ji$w<{&!n1v~YpnwgpYEuf^1PB7b3t(8!J_oKo_B=y6@ho+}#z>M!S8A&93M8WQ-isKPrQ7u(uX)<#D53`Xg%=^uKHKIr4Ba|Z#nAT{?nD|A0;QfwzQn|%w}KI#8=BFZaYW}D2L|s-v|inW~i-`_IxVbSQAxt&9|YWFPd7a1di}@y-jaOsR%J(9!+fZqsiu4CEJK<=pjDhUPfM>Na(Y@(sDVaMO>r^XcQp)!g^QGMaP=1Tw&O@l;vousL@U~1$aaean*y0I+H`^-E`I`{(>nle~F%Ik}(B8EmUsY90hKuXirACTm9I4cO$rw5LHeU4vKdgyo=ul_NSw3jX8XV34^`>(cF@k5q#a;$^hjR42W1W8%##TgkYE`$Nc*)a5Z?)1uRo0Pz94&F8tr8%C6i7OwhZ1=d4qmFvmd^R(|pHI~Lt=RvPU0O$m399K~5wEVHhYHIOvn;;TQc&=6ZrH7R*+y(%Y1>F(l2sYA5UCUkSKK1a9IP;grSRP_ryb{HTHqF~h`#f&96?zOD_NGn;5He{so;36HBlsC3fGDC0eA>~$QjAbJ7E$~TtvQ_Q1cC#k@?7VGejs8YV3wdqQ3hX5d~Uuz&uJHvg9g5PbSo&^5_V?W^wG>5A?QL0u6T8Z?&RkR^%&F71IYX#vORe{v*I8nQbDSV>*dx52lNn4-mZM0UcO7gW~j#g(O9ZvJI!!@*)q*#$-WRHLyYtpS_rY8P>GI_ynZyugT4J4EGq|5sz`$9-^L{8rpurQrXF;fM7f|$>p_PZ>!@YdK2ChwttBV1>?oXav5PZ*v-TJc^xaP`;OCr&Le?1Q{$&!(pwX+#z5EzD$UF(C<}~hmq`;#hq#7ZP+`4DbKjVwfE@Y1m)ZT~=`PU~v@^^p$=s`az#>>T%$ay-Mgz*+f(TIItRO&~Q!RCfHdQ&i*r%7_sq7@3-y_=1qexkJS8hv(sqRO4Z258lbR4}E^2Qdt7@_BSz+GawYHskgF*D*GkB0lW=A_HwH{x~*y(Ze{FKaoDaW>v4_E)7g*;wh~$D=39qfeS?T$bEmGONB_0To*v$|Ig*^xKZu_{kkweyUGi~sxQeJE!H*`df&9dhNYQW|y==N6|FR6KfyU}G19tdo0?bTWe(mtNc=w(jV!$>0?-$5=5d+2!73;&?j6q^RZW1U6OPov?}las`Ry-#KcklU<-*mUQO(FKkPJwR>o;x?SKZni284n|^#!b3Amy%6VcyFtS$n>}jJ`5Vqx*|s=dyocJRTPY`u4H0RrTj_P%)95cvx6wVg$fe{LrzeUzVbrE0=}HO?>snOgK{61;j8I8}Hhdw*$;vQFfO)e2_Aq9(BFwoy1+J3lXdZ_AyC$o~YptzBQRKL~Upy{(-P1W#ryq{T|k60|h-sURlyP-Uk2R6uC~JY@KJsRU#qEu>QN}TxXz6-Kz?7HBkw;FK=I+Te$T6N^5ti+-b&2RoVH3iS(scQJT*F;{~1ehB=o6#P_8+uly&^x=!Rh>M9%)`a&9|G4d!lf+mr)XrKZoUSCO@4kd#EK$B}QDMlF;EMO0ZjptX8Ey`zm_AQC}AuY~On;#;Zmap4$aN}=W_X(YF*j#l`R^gHueO?5PlBETxFxyy``C1qrBXVE}J?+GhnoF9}z9+#!GZTsj*jVfz-tFY9iZe`HF99{GqaL2maB-W3#TorRUQr0+{qylBlw&-D9jK|zjpx%`LXB*QKek@Fq{O@LS`ViLc*y-HYiw|gk2Gn~;|5%gpoKk9Gv{UQ@4BqCSfm8;MK9+sI~}rIlh~NEOE?6zXtBcAwIgWZ<-@N!1otiSTS2hiBrO*>UX`C&_TT-zTAE%!kdwE1T-S;PkDj=JW_-Z8B;yAWp&r>_;BA;zgivQyIEw^67da=9f-PmfQs@P}Js&T~iDLp*2R}+|=J(<8O}A7LOZ55Z!GumOUbk@Z|~3($iMO6nnOeo>LG}bo>GyFWn&#U-W&IOtM`x>8sR#d+n);kYpiQNg(s%--(6c%s-4OMs>u!imdnCF*WN-2Jz<}k?>@3HOnOJsPZQZa>EBMp2e32Q>e-A!Tz*NA6A>g%toIn|YF*f-ukor$Kp>B=A>;@T5G>P}n=|xk;)q*WuJSZTD(CupCwPvPr#9Ziw6V0GR;|61(ey>@|gTrrF(UKwIeBlx%Abyem_PF$XjR(DEse`bLEE-D2dUq<8K2U*5*Rn_Eu2vz4y*|l??-v-WR!lta2}~fmqdXqy_|k*@i4{pfEq%Q*}Gk^#j?HEs|fQX*zQ!zp6|Ly)abD&W3n#Q5gipX<}QbUe@lUGE?qP4n-gf)#>
hI(i8&?kUhcO7U?T_(>W|;YyH~{EAw&Q=;W)tkKg!v(A%~c1h{*p2sKRFs_cb*KKpg+kPS8Sma61~7?ZeeY*iIqR@3e6oJ0$AZ?hg1Q#zIG??{&hn+pi5K$4R=X(Rj33q8aznt51qg>X^eaRCX4A<#=$5c?KZapYGfo5T&ZKmlq0`{A^^)6fRg#Z!AMu+0Nz$fuSXXhTd|{%+>~8#yK_Q!m;C^9#8@0(cXQnv&Hs6j!W&vy)9|<9VuNBP`5|9pKpO8k6J8_xsX}AJyXJX8E$c(`~2q91B*HD$y4!{M&?pf6m?yoIFSbOQtPkw5(gn#$D#iZ%lY15sF+QZ4jFwJS)w&8;&KqdU6&ssyQVhybGUIesR&L`4wQ|{J5@Sf75Mc>?vUm5m-uS&O)(;8aeY@1f$QjvI9Mc|BXc$o4mmy9Z(dIFm)MB6NuEHG+Z?D8J<(Y6yLL2kWrkxGEyTthb~6zui14gg3AI4Sn9%Bo&26HW!$#`;?KrLKv44dJk)Qlf!5MFM%i^=&(1F3R~9w}v212vt6rVYp@FKgEr0@vO12=e9)$7*`UmO|VaO?ZzkKQ<-g3RW`e(2s3dPhms;>gr>$U`e90_f&ar~Z;cVn>wE-v=Igpu$Yr>Cd{oBUV%qWUaf3e)xHFvAqs}C*g+6XL_c-=Q^hlLv+spvLjxe84zqtQfU74+=?40iCoat&=(0A$sR{@elxrc*f8k$8WxWnDoTYK%z&XJ!z_m34RKORvK6o;;GRHpGCOI=yyUf@`<793B;h8BB!Dud{|Ciq9oWl{TZng4vh}cHlo@wE%&j+!vFuRL7CY0T*?b-58%}>aut^^43YJt#(Vv++Z{8vCiHmGX_EgwSwozSR7}m7qefenEy`08{=)0}mWDc8k=-sozYwzjP>?}ZrU#(OWfUZJe!g3x2Tj+?<0NL#e5bN4HpA(eSqvx*w#&=HL7FLe}kdtaS_~dYd@byfK`48W^^iPYX>qyD+&@-Ax%Hm(IufDaNM137$6Nmc5O4+1y3lF5yyvvJKdx6wqY<+Tbi(zG9;>)l@yHw=t)OjN`iEr!kB)ev`UG%aDJiFHXmF2U*aOy12R$_mQ@xLE4!^Q>r4uNJbt>}6SPF}K{1FRwcC?N4Q5p&Py8lKP0jA>Ux(z1kOC98Rj)FR#cXxrF@zEwOHO`12G>^(W!{uVmGH4I5gLSKw1I$V=Cvp~vVfC&|G8S!f}VKFoDosA<-w36Jj1@*AJuS-lXih5c5x>(&5~4_oxI57XBBvDn4O5{C(u;G^!W!MdQW?7BtoSuAZ!-<4$El~n2rs$vU@m+;MPFo%;hb}uklyNh3ce$lO5ZONP#dtkiG26mq^sCZf7EwjxQFJ(ZduFYb5e(d&(?gDtJX*LWp-40zC8593&0>$ZaKv)F`HKWiBh0rY{*(pd4{!k#E5{cC^4{lp$-xeHd>?%pAxa=;UpNK>GTZ`3yT6Dd9DM=d1gEAKQ~%I?rDifSI#{t3W9TCnWaZauga`4yA&sxr_;?Be*5`u7oanrTXDZnvam6`MajD%tSnt|gYEURu?XdH8oBo&s0+3Nc~?pQvCq?KSL*9s27_fdSv2*swCA6RXNLUs=e6=%wwURWrm;S9WT`EkJ|WWY)zUQ_;Qg?FB0&CzR(V!xOn;lE|$@OU>wpBWB`>DC?6DgWDAm)xMU;N6w|Dtwt+2sIyC8{QA>rJeD9SfiGdlX93qbJE9Fp$=N2*ynJv=IZvIN>Fp%JHPS(4>OmC`=Oir?qzbP$}PuChoyuGqPSN_ZiG?`m)8TODMYi04i$BQaXBX8M`uz1r(P)56phLQ<)|9OKm{j7$=JjaO-A|~@p!e#MzdMN00_;BO%O+4&p6*_x+o~K&MUKhfK;s*QlZ8+KtM{Y7=z2ksJkhK4OeYG~ENH%9|o*?nX*)!fg-DeBI?O)GCJPIo3py?m(&ldfKV>KZosbUkgTJ3q0$&H$M(D{4lCc!(uiOwUy-KiFjbX!wh+Qk4y?Ab`!TA|5v@kL>jN=hBw3IyNeW;byQBuhGGfn)FAnd4~p)^82_+hb5(QQzz|_9Vk1j~=c&mHeGrm83n)0;=2+!+l7WpxXYf3&uJpWm-i&NnzJnreB4XUDJSLWw&4rE`qV3ZRB)~4okHh0ya2T$@85hajXB6)}*FnedvOkxYX;-R<}{-Vj$mLW(pIEI$d3mFjLJ>7*(d&qS}Og8f;Bp^74g$l}BUOtzrNG&3BnLVnua-ymi6>)U*53U4?fcs@>+&y;rlvS`Ydkp*%rI&=V~X@2Vy%>I#pBnkTfS?wh%wNWoS_n0DW^MN@0%LqYD%H!VP9s9vg~4@}43NKYWm3{>!NMv;8P^0CoKa4Ppog!xW)|xbFI5;gIk3x%$#wS_I5xCR^4)gxF@d}5Y*LI}nsDy3rlue2o^Q~KThgL)x8njeGn9K*g{yMFv@yBkYkW!4S7^j)Q*(!ndqO#Lxad4(B_uVW>w?Xi&*1}W*$;KFR6i%?pKHBvFsdSe$>s7R5pWZw9D_39ANUGfz03rYVdg&#A=>0jQslm$)J`-vIEBPJp!=`IK1&{7EHtwb{VCx^*J0h+52Mld6@i=TP}6~YE-T$nRqdl_n&nwVm@w)0QmCkJ<~$ivH)p!h+CHGY=VE($)$cnbt6bT=Ua`ftH}W6yZfJYt2g0O3RQ`c?K+_4%eX*jMwJfFHO8c^LU}SPjn6Ab?JCKDa7MWp%{`oIV%z+;b3x-07%FCwL^S$A|C)8w&APim?b`c*=IAcsqxfDdlLqkd%8@Qw(XeJ*8Cy!|9Io^j%L;fzej+YQ4AHJU^@N{DokzWI{qLSlTA=g@1gJbD)z93jWaicfE)Y#VGd*;uG+klOwvqe~?Fsm8uoD*&PV8k+3{;-OCaH(^Gz32MepB4h_jm`~l8%u(}bW&4bi@%qsMJF=2`nIOMkeQ>MLV3AIxZEF^`>`Z8|E*h&{m~D{x*?K%<2rnXf_Nh8PST3wJy(A>cPdHClyG*3%M_WxDI2Bb#swkosxi-4CGXN6{kjaO+F8}+c`l%i*c%Q4F&@^i2e7Z}a{6MCSTMsR+Ub2fyt;_EqpeX_Cvdoxxd;Mh$F+_6xXu~2X`ooBxVujS>QDNulUhiw31REJ4%k6(b+p1~GHgnve~6?|_{wwwh~)`*Ntf4txtt&>*s=ZD*|T-?wRM|zA33IGV`qeX0`<)MKh7ICwT&f!`q3ivE|GuXT>={H7NgYp{7N#mxuY`!_6e^~Qq#eW_9mYoaUxsD$-EI^HqdIsX<8_(q)7kt9RlPX!0p5p;(P#D1l#X@u~o7-HF3=gbIefEouM%Pt0j%~nbD1A61J26$hOkRst(N0tEt4?ucZ=HA#Wyh%q?jx?Y*ou=J9sRg9~5m|mF9i;Z^d_KrQ5cYh8ee-~QSZ4e>V>~<@+h8{;3{#-lRm>F3NA|x1+L*MO5^_?P416ShfHDKYeKCsZo7Kd(15sRvlLk@KlXxN8i>0(oD?%w-#NyHGIgqdq#IfSDCyG#MU!+rP78@vm$gg&&(LYOOhU16q?bJVoHUBDU{%d3-bvp&Cbob{{bTgA-7(Sp&SEGw2h6?jw-e+oHnLsSM5Xq2Yf@W3XQE1{miS`
%Cb}JKnznI(yW9%!|9hO~7T!w!sky!{psV^$GWjuZ^&|m^UIUGrk*WB6LMON$Aw50NLYr4+(Ezm$X}n4IH(tYSOO_!cg_ICT6F@snlNIXjF@YHT=t}-UfyMCQh$qKNg;U`vL}o{lL2Ei;W0xgIyC%kxn{#1|AZOO@o`C@3g>q-A}p1+bv#;R9>RlDLHOrWJbub7jK}cuRh1JcME+Gv-UzWxdCmxq+o&*Rn46ZwXS_&a*{;SC8^QUGS2~+R(KnKvuTI~u}hqe1ch?~!jl@)t=aD;cLaE+2lTqy4Cp|}WBC%$_Qcmf{hE`W&dY#n%gUWES5%+m6TUsFPn0I*4_ytMYYLVnvRsw;+>{C>0;4xZH$Y^8j9sFQVZ&+UL3^qARZQDH%JWb`Jh-5t@KkOwv2Ci_Kj?PDCQ;lU3poDHQf+<$ijk=B=v!qQDQ~GSYI7ufpXvvDy-7eGO}#fRq*0a`A+k=y+C^pCiZB3_M4k`CqkRm+UU*GWTLCX&9Be0})vO9rWg$M&vZya5D0rDr_+MH>KRxRJV!X`-8;hk(yal&IvNUJa(iJ}qA+Y&e!0Qgk<2h!vG0Gj`a5Te7=X1)i<&R_K^zvAPoT&QK*H;sr@)5~G{TKHb&P14jnd_+FkXJ5;*3*J?0MM!|!KLh;GTx}X=j^WX2sjEcl79G5o91CcRxO!vZ){G}B+Bg|I!$OhqRr&4?B?wbij^vPNL0?al$AEVO^dsH@Y##4tHya*AoDw<*42I(I^^+P+Ym(mgT<$(eoHiSf}>i{Ttyk4elgpgX#+5`@s>IAZ@!JvDKh5)vkjgR@uXR2>~CqGR>o@LBCokzg2L)W<{gU=;%Yu2Ap&mhaT&R`dd?-QYg6~b_TF$YY8cCpZW*sV1P;EUI8$C;oR8dH!q(4F8_Hkeh-jwX@C{9JR*)T#bEzADcd@WJv1YA_GU<1A`vR=#8R0A_aBDs-&LMkUzW~VH1@GgLGVrpqx3{>~5z9T5~4?U^ClUr>1u2yFrb*`??`ll(;{SBGJi3>_0riYQ(Y9DKXDA9D*o;0E&Le{f~jD9HGdCGBpGE%{?WUGXzGqex?r-Gpk;mLb5GPtNBMx`FyVq6FjylfDeLk^vKOW(rD8>BfyF+KUq(`U_Z|BZ2Vl&CYcSl0qVPU?!o_^6)wPP~1A(&7M63#_*t`u^e;NrLd#kdLJHx!QCDq&f$}Hk#@+WSw80avfHWfw?M2abXHZcONM{C1RQJ-niywDNc>yNZsT$@HyMK3q`=9eOMwJVzi-`M_k3!_CQV@Cc3=;Y-=hF+?i)ssH@X^Mea}*^^kbN_yE2`uM0;9n>elU!V_pb%@eRav3iRkiBV{x=u8$axC<$;TGVm#brlkdFTyRK{*ozR3mKg&z7i}cx6Os7UCg-KHl1?wXd>|W3=5k_;n}e1(S3*$RCpN26oi?FBYbzR{Jrsr(tnkPPGiGa<)*x14m&cDI5rt$?JyKTYVXZ0MsXf7MIPPPV@R@CL&@qODfX5fgk5uqlZVW50S$^^?=U6@x^Tqk?NR4VO~O=`!{0a~>X#;K!=>tFk@y3KMo98JN%$H6z*EMl)?m$afgC-+E~b2tt4HPNMBh~WP=He>WVf%?yiF=lirc=C#7yq&h#8%$LZd9fr!QnPs~2%+pB5ta@}2(6Ar~`!L(>@9$<3_=65OU0FM$U0dyU%ha?{WMe>x#a85II+98+Y@J>useAjFaMe+;4Ia($UjRCw3w^?2qy-Z<+6=3!qDd?R}|4sO9i3yP%R4;=?g^(WO1KTJzYdc91klhdGpo86V2dpL;{EoZ4=OnTQD3P6*_J8Wu^x0wg@{~H*CK^4Bof0 else 0;num_sequences=(self.num_tokens[si]-1-phase)//self.seq_len;sequence_order=self.rng.permutation(num_sequences);self.start_inds[si]=(phase+sequence_order*self.seq_len).tolist() + def next_batch(self,global_tokens,grad_accum_steps): + device_tokens=global_tokens//(self.world_size*grad_accum_steps);device_batch_size=device_tokens//self.seq_len;remaining=np.array([len(s)for s in self.start_inds],dtype=np.float64);x=torch.empty((device_batch_size,self.seq_len),dtype=torch.int64);y=torch.empty((device_batch_size,self.seq_len),dtype=torch.int64) + for bi in range(device_batch_size): + total=remaining.sum() + if total<=0: + for si in range(len(self.files)):self._reset_shard(si) + remaining=np.array([len(s)for s in self.start_inds],dtype=np.float64);total=remaining.sum() + probs=remaining/total;si=int(self.rng.choice(len(self.files),p=probs));start_ind=self.start_inds[si].pop();remaining[si]-=1;mm=_get_shard_memmap(self.files[si]);window=torch.as_tensor(np.array(mm[start_ind:start_ind+self.seq_len+1],dtype=np.int64));x[bi]=window[:-1];y[bi]=window[1:] + return x.to(self.device,non_blocking=True),y.to(self.device,non_blocking=True) + + + + + + +class RMSNorm(nn.Module): + def __init__(self,eps=None):super().__init__();self.eps=eps + def forward(self,x):return F.rms_norm(x,(x.size(-1),),eps=self.eps) +class CastedLinear(nn.Linear): + def forward(self,x):w=self.weight.to(x.dtype);bias=self.bias.to(x.dtype)if self.bias is not None else None;return F.linear(x,w,bias) +class Rotary(nn.Module): + def __init__(self,dim,base=1e4,train_seq_len=1024,rope_dims=0):super().__init__();self.dim=dim;self.base=base;self.train_seq_len=train_seq_len;self.rope_dims=rope_dims if rope_dims>0 else 
dim;inv_freq=1./base**(torch.arange(0,self.rope_dims,2,dtype=torch.float32)/self.rope_dims);self.register_buffer('inv_freq',inv_freq,persistent=False);self._seq_len_cached=0;self._cos_cached=None;self._sin_cached=None + def forward(self,seq_len,device,dtype): + if self._cos_cached is None or self._sin_cached is None or self._seq_len_cached!=seq_len or self._cos_cached.device!=device: + rd=self.rope_dims + if seq_len>self.train_seq_len:scale=seq_len/self.train_seq_len;new_base=self.base*scale**(rd/(rd-2));inv_freq=1./new_base**(torch.arange(0,rd,2,dtype=torch.float32,device=device)/rd) + else:inv_freq=self.inv_freq.to(device) + t=torch.arange(seq_len,device=device,dtype=inv_freq.dtype);freqs=torch.outer(t,inv_freq);self._cos_cached=freqs.cos()[None,:,None,:];self._sin_cached=freqs.sin()[None,:,None,:];self._seq_len_cached=seq_len + return self._cos_cached.to(dtype=dtype),self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x,cos,sin,rope_dims=0): + if rope_dims>0 and rope_dims0 else None;self.smear_gate=CastedLinear(h.smear_gate_width,1,bias=False) if h.smear_gate_enabled else None;_=setattr(self.smear_gate,"_zero_init",True) if self.smear_gate is not None else None;self.smear_lambda=nn.Parameter(torch.zeros(1,dtype=torch.float32)) if h.smear_gate_enabled else None;self.smear_gate_width=h.smear_gate_width if h.smear_gate_enabled else 0;self.register_buffer('logit_temp',torch.tensor(1.0,dtype=torch.float32)) + if h.embedding_dim!=h.model_dim:self.embed_proj=CastedLinear(h.embedding_dim,h.model_dim,bias=False);self.head_proj=CastedLinear(h.model_dim,h.embedding_dim,bias=False) + else:self.embed_proj=None;self.head_proj=None + self.num_encoder_layers=h.num_layers//2;self.num_decoder_layers=h.num_layers-self.num_encoder_layers;self.blocks=nn.ModuleList([Block(h.model_dim,h.num_heads,h.num_kv_heads,h.mlp_mult,h.rope_base,h.qk_gain_init,h.train_seq_len,layer_idx=i,ln_scale=h.ln_scale)for i in range(h.num_layers)]) + if h.rope_dims>0: + head_dim=h.model_dim//h.num_heads + for block in self.blocks:block.attn.rope_dims=h.rope_dims;block.attn.rotary=Rotary(head_dim,base=h.rope_base,train_seq_len=h.train_seq_len,rope_dims=h.rope_dims) + self.final_norm=RMSNorm();self.lm_head=None if h.tie_embeddings else CastedLinear(h.embedding_dim,h.vocab_size,bias=False) + if self.lm_head is not None:self.lm_head._zero_init=True + if h.xsa_last_n>0: + for i in range(max(0,h.num_layers-h.xsa_last_n),h.num_layers):self.blocks[i].attn.use_xsa=True + if h.parallel_residual_start>=0: + for i in range(h.parallel_residual_start,h.num_layers):self.blocks[i].parallel=True + if h.gate_attn_out: + for block in self.blocks: + block.attn.gate_attn_out=True + block.attn.gate_width=h.gate_width + block.attn.gate_attn_src=h.gate_attn_src + block.attn.attn_gate_proj=CastedLinear(h.gate_width,h.num_heads,bias=False) + block.attn.attn_gate_proj._zero_init=True + block.attn.attn_gate_proj.float() + self.looping_active=False + if h.num_loops>0: + loop_seg=list(range(h.loop_start,h.loop_end+1));all_indices=list(range(h.loop_start)) + for _ in range(h.num_loops+1):all_indices.extend(loop_seg) + all_indices.extend(range(h.loop_end+1,h.num_layers));num_enc=len(all_indices)//2;self.encoder_indices=all_indices[:num_enc];self.decoder_indices=all_indices[num_enc:] + else:self.encoder_indices=list(range(self.num_encoder_layers));self.decoder_indices=list(range(self.num_encoder_layers,h.num_layers)) + 
self.num_skip_weights=min(len(self.encoder_indices),len(self.decoder_indices));self.skip_weights=nn.Parameter(torch.ones(self.num_skip_weights,h.model_dim,dtype=torch.float32));self.skip_gates=nn.Parameter(torch.zeros(self.num_skip_weights,h.model_dim,dtype=torch.float32))if h.skip_gates_enabled else None + self._init_weights() + def _init_weights(self): + if self.tie_embeddings:nn.init.normal_(self.tok_emb.weight,mean=.0,std=self.tied_embed_init_std) + for(name,module)in self.named_modules(): + if isinstance(module,nn.Linear): + if getattr(module,'_zero_init',False):nn.init.zeros_(module.weight) + elif module.weight.ndim==2 and module.weight.shape[0]>=64 and module.weight.shape[1]>=64:nn.init.orthogonal_(module.weight,gain=1.) + def forward_logits(self,input_ids): + x=self.tok_emb(input_ids) + if self.bigram is not None:x=x+self.bigram(input_ids) + x=F.rms_norm(x,(x.size(-1),)) + if self.smear_gate is not None: + x_prev=torch.cat([torch.zeros_like(x[:,:1]),x[:,:-1]],dim=1) + lam=self.smear_lambda.to(dtype=x.dtype) + g=torch.sigmoid(self.smear_gate(x[...,:self.smear_gate_width])) + x=x+lam*g*x_prev + if self.embed_proj is not None:x=self.embed_proj(x) + x0=x;skips=[];enc_iter=self.encoder_indices if self.looping_active else range(self.num_encoder_layers);dec_iter=self.decoder_indices if self.looping_active else range(self.num_encoder_layers,self.num_encoder_layers+self.num_decoder_layers) + for i in enc_iter:x=self.blocks[i](x,x0);skips.append(x) + for(skip_idx,i)in enumerate(dec_iter): + if skip_idxG.size(1) + if transposed:X=X.T + for _ in range(steps):A=X@X.T;B=b*A+c*A@A;X=a*X+B@X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self,params,lr,momentum,backend_steps,nesterov=True,weight_decay=.0,row_normalize=False):super().__init__(params,dict(lr=lr,momentum=momentum,backend_steps=backend_steps,nesterov=nesterov,weight_decay=weight_decay,row_normalize=row_normalize)) + @torch.no_grad() + def step(self,closure=None): + loss=None + if closure is not None: + with torch.enable_grad():loss=closure() + distributed=dist.is_available()and dist.is_initialized();world_size=dist.get_world_size()if distributed else 1;rank=dist.get_rank()if distributed else 0 + for group in self.param_groups: + params=group['params'] + if not params:continue + lr=group['lr'];momentum=group['momentum'];backend_steps=group['backend_steps'];nesterov=group['nesterov'];total_params=sum(int(p.numel())for p in params);updates_flat=torch.zeros(total_params,device=params[0].device,dtype=torch.bfloat16);curr=0 + for(i,p)in enumerate(params): + if i%world_size==rank and p.grad is not None: + g=p.grad;state=self.state[p] + if'momentum_buffer'not in state:state['momentum_buffer']=torch.zeros_like(g) + buf=state['momentum_buffer'];buf.mul_(momentum).add_(g) + if nesterov:g=g.add(buf,alpha=momentum) + if group.get('row_normalize',False):row_norms=g.float().norm(dim=-1,keepdim=True).clamp_min(1e-07);g=g/row_norms.to(g.dtype) + g=zeropower_via_newtonschulz5(g,steps=backend_steps);g*=max(1,g.size(0)/g.size(1))**.5;updates_flat[curr:curr+p.numel()]=g.reshape(-1) + curr+=p.numel() + if distributed:dist.all_reduce(updates_flat,op=dist.ReduceOp.SUM) + wd=group.get('weight_decay',.0);curr=0 + for p in params: + if wd>.0:p.data.mul_(1.-lr*wd) + g=updates_flat[curr:curr+p.numel()].view_as(p).to(dtype=p.dtype);p.add_(g,alpha=-lr);curr+=p.numel() + return loss +CONTROL_TENSOR_NAME_PATTERNS=tuple(pattern for pattern in 
os.environ.get('CONTROL_TENSOR_NAME_PATTERNS','attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,bigram.scale,smear_lambda,readout_delta,attn_gate_proj,smear_gate').split(',')if pattern) +class Optimizers: + def __init__(self,h,base_model): + block_named_params=list(base_model.blocks.named_parameters());matrix_params=[p for(name,p)in block_named_params if p.ndim==2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)];scalar_params=[p for(name,p)in block_named_params if p.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)] + if base_model.skip_weights.numel()>0:scalar_params.append(base_model.skip_weights) + top_named=dict(base_model.named_parameters()) + for name,p in top_named.items(): + if "blocks." in name:continue + if name=="tok_emb.weight":continue + if "bigram.embed" in name:continue # tok group below + elif "bigram.proj" in name:matrix_params.append(p) + elif "bigram.scale" in name:scalar_params.append(p) + elif "smear_gate" in name and "lambda" not in name:matrix_params.append(p) + elif "smear_lambda" in name:scalar_params.append(p) + if base_model.skip_gates is not None and base_model.skip_gates.numel()>0:scalar_params.append(base_model.skip_gates) + token_lr=h.tied_embed_lr if h.tie_embeddings else h.embed_lr;tok_params=[{'params':[base_model.tok_emb.weight],'lr':token_lr,'base_lr':token_lr}];self.optimizer_tok=torch.optim.AdamW(tok_params,betas=(h.beta1,h.beta2),eps=h.adam_eps,weight_decay=h.embed_wd,fused=True);self.optimizer_muon=Muon(matrix_params,lr=h.matrix_lr,momentum=h.muon_momentum,backend_steps=h.muon_backend_steps,weight_decay=h.muon_wd,row_normalize=h.muon_row_normalize) + for group in self.optimizer_muon.param_groups:group['base_lr']=h.matrix_lr + self.optimizer_scalar=torch.optim.AdamW([{'params':scalar_params,'lr':h.scalar_lr,'base_lr':h.scalar_lr}],betas=(h.beta1,h.beta2),eps=h.adam_eps,weight_decay=h.adam_wd,fused=True);self.optimizers=[self.optimizer_tok,self.optimizer_muon,self.optimizer_scalar] + if base_model.lm_head is not None:self.optimizer_head=torch.optim.Adam([{'params':[base_model.lm_head.weight],'lr':h.head_lr,'base_lr':h.head_lr}],betas=(h.beta1,h.beta2),eps=h.adam_eps,fused=True);self.optimizers.insert(1,self.optimizer_head) + else:self.optimizer_head=None + def __iter__(self):return iter(self.optimizers) + def zero_grad_all(self): + for opt in self.optimizers:opt.zero_grad(set_to_none=True) + def step(self): + for opt in self.optimizers:opt.step() + self.zero_grad_all() +def restore_fp32_params(model): + for module in model.modules(): + if isinstance(module,CastedLinear):module.float() + for(name,param)in model.named_parameters(): + if(param.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS))and param.dtype!=torch.float32:param.data=param.data.float() +def collect_hessians(model,train_loader,h,device,n_calibration_batches=64): + hessians={};hooks=[] + def make_hook(name): + def hook_fn(module,inp,out): + x=inp[0].detach().float() + if x.ndim==3:x=x.reshape(-1,x.shape[-1]) + if name not in hessians:hessians[name]=torch.zeros(x.shape[1],x.shape[1],dtype=torch.float32,device=device) + hessians[name].addmm_(x.T,x) + return hook_fn + for(name,module)in model.named_modules(): + if isinstance(module,CastedLinear)and module.weight.numel()>65536: + cat=classify_param(name+'.weight') + if cat in('mlp','attn'):hooks.append(module.register_forward_hook(make_hook(name+'.weight'))) + if model.tie_embeddings: + 
hook_module=model.head_proj if model.head_proj is not None else model.final_norm + def make_output_hook(name): + def hook_fn(module,inp,out): + x=out.detach().float() + if x.ndim==3:x=x.reshape(-1,x.shape[-1]) + if name not in hessians:hessians[name]=torch.zeros(x.shape[1],x.shape[1],dtype=torch.float32,device=device) + hessians[name].addmm_(x.T,x) + return hook_fn + hooks.append(hook_module.register_forward_hook(make_output_hook('tok_emb.weight'))) + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches):x,_=train_loader.next_batch(h.train_batch_tokens,h.grad_accum_steps);model.forward_logits(x) + for hook in hooks:hook.remove() + for name in hessians:hessians[name]=hessians[name].cpu()/n_calibration_batches + return hessians +def gptq_quantize_weight(w,H,clip_sigmas=3.,clip_range=63,block_size=128): + W_orig=w.float().clone();rows,cols=W_orig.shape;H=H.float().clone();dead=torch.diag(H)==0;H[dead,dead]=1;damp=.01*H.diag().mean();H.diagonal().add_(damp);perm=torch.argsort(H.diag(),descending=True);invperm=torch.argsort(perm);W_perm=W_orig[:,perm].clone();W_perm[:,dead[perm]]=0;H=H[perm][:,perm];Hinv=torch.cholesky_inverse(torch.linalg.cholesky(H));Hinv=torch.linalg.cholesky(Hinv,upper=True);row_std=W_orig.std(dim=1);s=(clip_sigmas*row_std/clip_range).clamp_min(1e-10).to(torch.float16);sf=s.float();Q=torch.zeros(rows,cols,dtype=torch.int8);W_work=W_perm.clone() + for i1 in range(0,cols,block_size): + i2=min(i1+block_size,cols);W_block=W_work[:,i1:i2].clone();Hinv_block=Hinv[i1:i2,i1:i2];Err=torch.zeros(rows,i2-i1) + for j in range(i2-i1):w_col=W_block[:,j];d=Hinv_block[j,j];q_col=torch.clamp(torch.round(w_col/sf),-clip_range,clip_range);Q[:,i1+j]=q_col.to(torch.int8);err=(w_col-q_col.float()*sf)/d;Err[:,j]=err;W_block[:,j:]-=err.unsqueeze(1)*Hinv_block[j,j:].unsqueeze(0) + if i21 and any(k in name for k in _FORCE_INT8_PT): + ma=t.abs().max().clamp_min(1e-10);sc=(ma/127.).float();q=torch.clamp(torch.round(t/sc),-127,127).to(torch.int8) + result[name+'.q_pt']=q;result[name+'.scale_pt']=sc;meta[name]='pertensor int8 (control)';continue + if t.is_floating_point()and t.ndim==2 and any(k in name for k in _FORCE_INT8_SMALL): + rm=t.abs().amax(dim=1,keepdim=True).clamp_min(1e-10);s=(rm/127.).squeeze(-1).to(torch.float16);sf=s.float().view(-1,1) + q=torch.clamp(torch.round(t/sf),-127,127).to(torch.int8);result[name+'.q']=q;result[name+'.scale']=s;meta[name]='simple int8 (small matrix)';continue + result[name]=t.to(torch.float16)if t.is_floating_point()else t;meta[name]='passthrough (float16)';continue + if 'bigram.embed' in name: + bits=6;qmax=2**(bits-1)-1;row_max=t.abs().amax(dim=1,keepdim=True).clamp_min(1e-10);s=(row_max/qmax).squeeze(-1).to(torch.float16);sf=s.float().view(-1,1);q=torch.clamp(torch.round(t/sf),-qmax,qmax).to(torch.int8);result[name+'.q']=q;result[name+'.scale']=s;meta[name]=f'simple int{bits} (bigram embed)';continue + cs=h.embed_clip_sigmas if'tok_emb'in name else h.matrix_clip_sigmas;bits=h.embed_bits if'tok_emb'in name else h.matrix_bits;q,s=gptq_quantize_weight(t,hessians[name],clip_sigmas=cs,clip_range=2**(bits-1)-1);result[name+'.q']=q;result[name+'.scale']=s;meta[name]=f"gptq (int{bits})" + categories=collections.defaultdict(set) + for(name,cat)in meta.items():short=re.sub('\\.\\d+$','',re.sub('blocks\\.\\d+','blocks',name));categories[cat].add(short) + log('Quantized weights:') + for cat in sorted(categories):log(f" {cat}: {", ".join(sorted(categories[cat]))}") + return result,meta +def dequantize_mixed(result,meta,template_sd): + out={} + 
for(name,orig)in template_sd.items(): + info=meta.get(name) + if info is None:continue + orig_dtype=orig.dtype + if'passthrough'in info: + t=result[name] + if t.dtype==torch.float16 and orig_dtype in(torch.float32,torch.bfloat16):t=t.to(orig_dtype) + out[name]=t;continue + if'pertensor'in info: + q=result[name+'.q_pt'];sc=result[name+'.scale_pt'] + out[name]=(q.float()*sc.float()).to(orig_dtype);continue + q,s=result[name+'.q'],result[name+'.scale'] + if s.ndim>0:out[name]=(q.float()*s.float().view(q.shape[0],*[1]*(q.ndim-1))).to(orig_dtype) + else:out[name]=(q.float()*float(s.item())).to(orig_dtype) + return out +_BSHF_MAGIC=b'BSHF' +def _byte_shuffle(data,stride=2): + if stride<=1 or len(data)0: + base_model.train();chunk_seqs=(chunk_end-chunk_start)//seq_len + if chunk_seqs>0: + cos_lr=h.ttt_lr*.5*(1.+math.cos(math.pi*ci/max(num_chunks-1,1))) + for pg in optimizer.param_groups:pg['lr']=cos_lr + my_seq_s=chunk_seqs*rank//world_size;my_seq_e=chunk_seqs*(rank+1)//world_size;my_chunk_seqs=my_seq_e-my_seq_s + for _ep in range(h.ttt_epochs): + for bs in range(0,my_chunk_seqs,batch_seqs): + be=min(bs+batch_seqs,my_chunk_seqs);actual_bs=my_seq_s+bs;start_tok=chunk_start+actual_bs*seq_len;end_tok=chunk_start+(my_seq_s+be)*seq_len+1 + if end_tok>val_data.val_tokens.numel():continue + local=val_data.val_tokens[start_tok:end_tok].to(device=device,dtype=torch.int64);x=local[:-1].reshape(-1,seq_len);y=local[1:].reshape(-1,seq_len);optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type='cuda',dtype=torch.bfloat16):loss=base_model(x,y) + loss.backward() + if world_size>1: + for p in ttt_params: + if p.grad is not None:dist.all_reduce(p.grad,op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params,1.);optimizer.step() + if dist.is_available()and dist.is_initialized():dist.all_reduce(loss_sum,op=dist.ReduceOp.SUM);dist.all_reduce(token_count,op=dist.ReduceOp.SUM);dist.all_reduce(byte_count,op=dist.ReduceOp.SUM) + for p in base_model.parameters():p.requires_grad_(True) + base_model.eval();return _loss_bpb(loss_sum,token_count,byte_count) +def timed_eval(label,fn,*args,**kwargs):torch.cuda.synchronize();t0=time.perf_counter();val_loss,val_bpb=fn(*args,**kwargs);torch.cuda.synchronize();elapsed_ms=1e3*(time.perf_counter()-t0);log(f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms");return val_loss,val_bpb +def train_model(h,device,val_data): + base_model=GPT(h).to(device).bfloat16();restore_fp32_params(base_model);compiled_model=torch.compile(base_model,dynamic=False,fullgraph=True) + if h.distributed:model=DDP(compiled_model,device_ids=[h.local_rank],broadcast_buffers=False) + else:model=compiled_model + log(f"model_params:{sum(p.numel()for p in base_model.parameters())}");optimizers=Optimizers(h,base_model);train_loader=ShuffledSequenceLoader(h,device);max_wallclock_ms=1e3*h.max_wallclock_seconds if h.max_wallclock_seconds>0 else None + if max_wallclock_ms is not None:max_wallclock_ms-=h.gptq_reserve_seconds*1e3;log(f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms") + def training_frac(step,elapsed_ms): + if max_wallclock_ms is None:return step/max(h.iterations,1) + return elapsed_ms/max(max_wallclock_ms,1e-09) + def lr_mul(frac): + if h.warmdown_frac<=0:return 1. + if frac>=1.-h.warmdown_frac:return max((1.-frac)/h.warmdown_frac,h.min_lr) + return 1. 
+ def step_fn(step,lr_scale): + optimizers.zero_grad_all();train_loss=torch.zeros((),device=device) + for micro_step in range(h.grad_accum_steps): + if h.distributed:model.require_backward_grad_sync=micro_step==h.grad_accum_steps-1 + x,y=train_loader.next_batch(h.train_batch_tokens,h.grad_accum_steps) + with torch.autocast(device_type='cuda',dtype=torch.bfloat16,enabled=True):loss=model(x,y) + train_loss+=loss.detach();(loss/h.grad_accum_steps).backward() + train_loss/=h.grad_accum_steps;frac=min(step/h.muon_momentum_warmup_steps,1.)if h.muon_momentum_warmup_steps>0 else 1.;muon_momentum=(1-frac)*h.muon_momentum_warmup_start+frac*h.muon_momentum + for group in optimizers.optimizer_muon.param_groups:group['momentum']=muon_momentum + for opt in optimizers: + for group in opt.param_groups:group['lr']=group['base_lr']*lr_scale + if h.grad_clip_norm>0:torch.nn.utils.clip_grad_norm_(base_model.parameters(),h.grad_clip_norm) + optimizers.step();return train_loss + if h.warmup_steps>0: + initial_model_state={name:tensor.detach().cpu().clone()for(name,tensor)in base_model.state_dict().items()};initial_optimizer_states=[copy.deepcopy(opt.state_dict())for opt in optimizers];model.train() + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step,1.) + if warmup_step<=5 or(warmup_step+1)%10==0 or warmup_step+1==h.warmup_steps:log(f"warmup_step: {warmup_step+1}/{h.warmup_steps}") + if h.num_loops>0: + base_model.looping_active=True;log(f"loop_warmup:enabled encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}") + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step,1.) + if warmup_step<=5 or(warmup_step+1)%10==0 or warmup_step+1==h.warmup_steps:log(f"loop_warmup_step: {warmup_step+1}/{h.warmup_steps}") + base_model.looping_active=False + base_model.load_state_dict(initial_model_state,strict=True) + for(opt,state)in zip(optimizers,initial_optimizer_states,strict=True):opt.load_state_dict(state) + optimizers.zero_grad_all() + if h.distributed:model.require_backward_grad_sync=True + train_loader=ShuffledSequenceLoader(h,device) + ema_state={name:t.detach().float().clone()for(name,t)in base_model.state_dict().items()};ema_decay=h.ema_decay;training_time_ms=.0;stop_after_step=None;torch.cuda.synchronize();t0=time.perf_counter();step=0 + while True: + last_step=step==h.iterations or stop_after_step is not None and step>=stop_after_step;should_validate=last_step or h.val_loss_every>0 and step%h.val_loss_every==0 + if should_validate:torch.cuda.synchronize();training_time_ms+=1e3*(time.perf_counter()-t0);val_loss,val_bpb=eval_val(h,device,val_data,model);log(f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}");torch.cuda.synchronize();t0=time.perf_counter() + if last_step: + if stop_after_step is not None and step0 and not base_model.looping_active and frac>=h.enable_looping_at:base_model.looping_active=True;log(f"layer_loop:enabled step:{step} frac:{frac:.3f} encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}") + train_loss=step_fn(step,scale) + with torch.no_grad(): + for(name,t)in base_model.state_dict().items():ema_state[name].mul_(ema_decay).add_(t.detach().float(),alpha=1.-ema_decay) + step+=1;approx_training_time_ms=training_time_ms+1e3*(time.perf_counter()-t0);should_log_train=h.train_log_every>0 and(step<=5 or step%h.train_log_every==0 or stop_after_step is not None) + if should_log_train:tok_per_sec=step*h.train_batch_tokens/(approx_training_time_ms/1e3);log(f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} 
train_time: {approx_training_time_ms/60000:.1f}m tok/s: {tok_per_sec:.0f}") + reached_cap=max_wallclock_ms is not None and approx_training_time_ms>=max_wallclock_ms + if h.distributed and max_wallclock_ms is not None:reached_cap_tensor=torch.tensor(int(reached_cap),device=device);dist.all_reduce(reached_cap_tensor,op=dist.ReduceOp.MAX);reached_cap=bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap:stop_after_step=step + log(f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB");log('ema:applying EMA weights');current_state=base_model.state_dict();avg_state={name:t.to(dtype=current_state[name].dtype)for(name,t)in ema_state.items()};base_model.load_state_dict(avg_state,strict=True);return base_model,compiled_model +def train_and_eval(h,device): + random.seed(h.seed);np.random.seed(h.seed);torch.manual_seed(h.seed);torch.cuda.manual_seed_all(h.seed);val_data=ValidationData(h,device);log(f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob("fineweb_train_*.bin")))}");log(f"val_tokens: {val_data.val_tokens.numel()-1}");base_model,compiled_model=train_model(h,device,val_data);torch._dynamo.reset();timed_eval('pre-quantization post-ema',eval_val,h,device,val_data,compiled_model);serialize(h,base_model,Path(__file__).read_text(encoding='utf-8')) + if h.distributed:dist.barrier() + eval_model=deserialize(h,device) + if h.num_loops>0:eval_model.looping_active=True + compiled_model=torch.compile(eval_model,dynamic=False,fullgraph=True);timed_eval('quantized',eval_val,h,device,val_data,compiled_model) + if h.sliding_window_enabled:timed_eval('quantized_sliding_window',eval_val_sliding,h,device,val_data,eval_model) + if h.ttt_enabled and h.sliding_window_enabled: + del eval_model,compiled_model;torch._dynamo.reset();torch.cuda.empty_cache();ttt_model=deserialize(h,device) + if h.num_loops>0:ttt_model.looping_active=True + timed_eval('quantized_ttt',eval_val_ttt,h,device,val_data,ttt_model);del ttt_model + if h.etlb_enabled and h.sliding_window_enabled: + if'eval_model'not in dir(): + eval_model=deserialize(h,device) + if h.num_loops>0:eval_model.looping_active=True + timed_eval('quantized_sliding_etlb',eval_val_sliding_etlb,h,device,val_data,eval_model) +def main(): + world_size=int(os.environ.get('WORLD_SIZE','1'));local_rank=int(os.environ.get('LOCAL_RANK','0'));distributed='RANK'in os.environ and'WORLD_SIZE'in os.environ + if not torch.cuda.is_available():raise RuntimeError('CUDA is required') + if world_size<=0:raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8%world_size!=0:raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + device=torch.device('cuda',local_rank);torch.cuda.set_device(device) + if distributed:dist.init_process_group(backend='nccl',device_id=device);dist.barrier() + torch.backends.cuda.matmul.allow_tf32=True;torch.backends.cudnn.allow_tf32=True;torch.set_float32_matmul_precision('high');from torch.backends.cuda import enable_cudnn_sdp,enable_flash_sdp,enable_math_sdp,enable_mem_efficient_sdp;enable_cudnn_sdp(False);enable_flash_sdp(True);enable_mem_efficient_sdp(False);enable_math_sdp(False);torch._dynamo.config.optimize_ddp=False;h=Hyperparameters();set_logging_hparams(h) + if h.is_main_process: + os.makedirs('logs',exist_ok=True);log(100*'=',console=False);log('Hyperparameters:',console=True) + for(k,v)in sorted(vars(type(h)).items()): + if not k.startswith('_'):log(f" {k}: {v}",console=True) + 
log('='*100,console=False);log(f"Running Python {sys.version}",console=False);log(f"Running PyTorch {torch.__version__}",console=False);log(subprocess.run(['nvidia-smi'],stdout=subprocess.PIPE,stderr=subprocess.PIPE,text=True,check=False).stdout,console=False);log('='*100,console=False) + train_and_eval(h,device) + if distributed:dist.destroy_process_group() +if __name__=='__main__':main() \ No newline at end of file diff --git a/records/track_10min_16mb/2026-04-18_SP8192_BigramHash32_PathAv3/train_seed314.log b/records/track_10min_16mb/2026-04-18_SP8192_BigramHash32_PathAv3/train_seed314.log new file mode 100644 index 0000000000..bad6b94f3f --- /dev/null +++ b/records/track_10min_16mb/2026-04-18_SP8192_BigramHash32_PathAv3/train_seed314.log @@ -0,0 +1,162 @@ +W0418 08:36:14.070000 1185822 torch/distributed/run.py:803] +W0418 08:36:14.070000 1185822 torch/distributed/run.py:803] ***************************************** +W0418 08:36:14.070000 1185822 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0418 08:36:14.070000 1185822 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + bigram_dim: 32 + bigram_vocab_size: 16384 + compressor: brotli + data_dir: /workspace/parameter-golf/data/ + datasets_dir: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + etlb_clip: 3.0 + etlb_enabled: False + etlb_lr: 0.05 + etlb_steps: 5 + eval_seq_len: 2048 + eval_stride: 64 + gate_attn_out: True + gate_attn_src: proj + gate_width: 12 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 12.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/seed314.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_residual_start: 7 + qk_gain_init: 5.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + readout_groups: 16 + readout_scale: 0.5 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: seed314 + scalar_lr: 0.02 + seed: 314 + skip_gates_enabled: True + sliding_window_enabled: True + smear_gate_enabled: True + smear_gate_width: 12 + temp_cal_batches: 50 + temp_cal_enabled: False + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: /workspace/parameter-golf/data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_chunk_tokens: 32768 + ttt_enabled: True + ttt_epochs: 3 + ttt_lr: 0.005 + ttt_momentum: 0.9 + use_pass_readout: False + val_batch_tokens: 524288 + val_files: 
/workspace/parameter-golf/data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.72 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 128 +val_tokens: 40542208 +model_params:36486278 +gptq:reserving 12s, effective=588000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0078 val_bpb: 3.4874 +1/20000 train_loss: 9.0053 train_time: 0.0m tok/s: 8121715 +2/20000 train_loss: 12.2468 train_time: 0.0m tok/s: 7968338 +3/20000 train_loss: 11.0931 train_time: 0.0m tok/s: 7868157 +4/20000 train_loss: 9.5595 train_time: 0.0m tok/s: 7813534 +5/20000 train_loss: 8.3092 train_time: 0.0m tok/s: 7789123 +500/20000 train_loss: 3.2920 train_time: 0.9m tok/s: 7478176 +1000/20000 train_loss: 3.2087 train_time: 1.8m tok/s: 7458004 +1500/20000 train_loss: 3.1138 train_time: 2.6m tok/s: 7452878 +layer_loop:enabled step:1949 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2000/20000 train_loss: 3.0772 train_time: 3.6m tok/s: 7357202 +2500/20000 train_loss: 2.9839 train_time: 4.9m tok/s: 6737424 +3000/20000 train_loss: 3.0110 train_time: 6.2m tok/s: 6379454 +3500/20000 train_loss: 2.9472 train_time: 7.5m tok/s: 6146460 +4000/20000 train_loss: 2.8117 train_time: 8.8m tok/s: 5968274 +4000/20000 val_loss: 2.8492 val_bpb: 1.1031 +4393/20000 val_loss: 2.8076 val_bpb: 1.0870 +stopping_early: wallclock_cap train_time: 588031ms step: 4393/20000 +peak memory allocated: 39506 MiB reserved: 39574 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.80453276 val_bpb:1.08579894 eval_time:6116ms +Serialized model: 136555547 bytes +Code size: 18097 bytes +GPTQ:collecting Hessians from calibration data... 
+GPTQ:collected 67 Hessians in 13.1s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): bigram.scale, blocks.attn.q_gain, logit_temp, smear_lambda + pertensor int8 (control): blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights + simple int6 (bigram embed): bigram.embed.weight + simple int8 (small matrix): bigram.proj.weight, blocks.attn.attn_gate_proj.weight, smear_gate.weight +Serialized model quantized+brotli: 15976073 bytes +Total submission size quantized+brotli: 15994170 bytes +quantized val_loss:2.83292704 val_bpb:1.09679203 eval_time:24114ms +quantized_sliding_window val_loss:2.79016454 val_bpb:1.08023616 eval_time:119913ms +ttt:start chunks=1238 ttt_lr=0.005 ttt_epochs=3 +quantized_ttt val_loss:2.78678778 val_bpb:1.07892882 eval_time:335468ms diff --git a/records/track_10min_16mb/2026-04-18_SP8192_BigramHash32_PathAv3/train_seed42.log b/records/track_10min_16mb/2026-04-18_SP8192_BigramHash32_PathAv3/train_seed42.log new file mode 100644 index 0000000000..412791ae31 --- /dev/null +++ b/records/track_10min_16mb/2026-04-18_SP8192_BigramHash32_PathAv3/train_seed42.log @@ -0,0 +1,162 @@ +W0418 08:13:00.952000 1129791 torch/distributed/run.py:803] +W0418 08:13:00.952000 1129791 torch/distributed/run.py:803] ***************************************** +W0418 08:13:00.952000 1129791 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0418 08:13:00.952000 1129791 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + bigram_dim: 32 + bigram_vocab_size: 16384 + compressor: brotli + data_dir: /workspace/parameter-golf/data/ + datasets_dir: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + etlb_clip: 3.0 + etlb_enabled: False + etlb_lr: 0.05 + etlb_steps: 5 + eval_seq_len: 2048 + eval_stride: 64 + gate_attn_out: True + gate_attn_src: proj + gate_width: 12 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 12.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/validate_seed42.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_residual_start: 7 + qk_gain_init: 5.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + readout_groups: 16 + readout_scale: 0.5 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: validate_seed42 + scalar_lr: 0.02 + seed: 42 + skip_gates_enabled: True + sliding_window_enabled: True + smear_gate_enabled: True + smear_gate_width: 12 + temp_cal_batches: 50 + 
temp_cal_enabled: False + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: /workspace/parameter-golf/data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_chunk_tokens: 32768 + ttt_enabled: True + ttt_epochs: 3 + ttt_lr: 0.005 + ttt_momentum: 0.9 + use_pass_readout: False + val_batch_tokens: 524288 + val_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.72 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 128 +val_tokens: 40542208 +model_params:36486278 +gptq:reserving 12s, effective=588000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0074 val_bpb: 3.4873 +1/20000 train_loss: 9.0057 train_time: 0.0m tok/s: 7912372 +2/20000 train_loss: 12.2393 train_time: 0.0m tok/s: 7888650 +3/20000 train_loss: 11.0647 train_time: 0.0m tok/s: 7821158 +4/20000 train_loss: 9.5574 train_time: 0.0m tok/s: 7773211 +5/20000 train_loss: 8.3000 train_time: 0.0m tok/s: 7739746 +500/20000 train_loss: 3.2939 train_time: 0.9m tok/s: 7488686 +1000/20000 train_loss: 3.2082 train_time: 1.8m tok/s: 7469856 +1500/20000 train_loss: 3.1137 train_time: 2.6m tok/s: 7462044 +layer_loop:enabled step:1951 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2000/20000 train_loss: 3.0782 train_time: 3.6m tok/s: 7368866 +2500/20000 train_loss: 2.9845 train_time: 4.9m tok/s: 6745892 +3000/20000 train_loss: 3.0073 train_time: 6.2m tok/s: 6387016 +3500/20000 train_loss: 2.9421 train_time: 7.5m tok/s: 6143958 +4000/20000 train_loss: 2.8149 train_time: 8.8m tok/s: 5966483 +4000/20000 val_loss: 2.8488 val_bpb: 1.1030 +4393/20000 val_loss: 2.8077 val_bpb: 1.0870 +stopping_early: wallclock_cap train_time: 588110ms step: 4393/20000 +peak memory allocated: 39506 MiB reserved: 39574 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.80464366 val_bpb:1.08584188 eval_time:6074ms +Serialized model: 136555547 bytes +Code size: 18097 bytes +GPTQ:collecting Hessians from calibration data... 
+GPTQ:collected 67 Hessians in 13.1s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): bigram.scale, blocks.attn.q_gain, logit_temp, smear_lambda + pertensor int8 (control): blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights + simple int6 (bigram embed): bigram.embed.weight + simple int8 (small matrix): bigram.proj.weight, blocks.attn.attn_gate_proj.weight, smear_gate.weight +Serialized model quantized+brotli: 15973106 bytes +Total submission size quantized+brotli: 15991203 bytes +quantized val_loss:2.83289805 val_bpb:1.09678081 eval_time:23829ms +quantized_sliding_window val_loss:2.78993168 val_bpb:1.08014601 eval_time:120470ms +ttt:start chunks=1238 ttt_lr=0.005 ttt_epochs=3 +quantized_ttt val_loss:2.78662485 val_bpb:1.07886574 eval_time:336109ms diff --git a/records/track_10min_16mb/2026-04-18_SP8192_BigramHash32_PathAv3/train_seed999.log b/records/track_10min_16mb/2026-04-18_SP8192_BigramHash32_PathAv3/train_seed999.log new file mode 100644 index 0000000000..3abb54cbec --- /dev/null +++ b/records/track_10min_16mb/2026-04-18_SP8192_BigramHash32_PathAv3/train_seed999.log @@ -0,0 +1,162 @@ +W0418 09:00:01.395000 1241828 torch/distributed/run.py:803] +W0418 09:00:01.395000 1241828 torch/distributed/run.py:803] ***************************************** +W0418 09:00:01.395000 1241828 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0418 09:00:01.395000 1241828 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + bigram_dim: 32 + bigram_vocab_size: 16384 + compressor: brotli + data_dir: /workspace/parameter-golf/data/ + datasets_dir: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.9965 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.35 + etlb_clip: 3.0 + etlb_enabled: False + etlb_lr: 0.05 + etlb_steps: 5 + eval_seq_len: 2048 + eval_stride: 64 + gate_attn_out: True + gate_attn_src: proj + gate_width: 12 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 12.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/seed999.txt + logit_softcap: 30.0 + loop_end: 5 + loop_start: 3 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.022 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.095 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_residual_start: 7 + qk_gain_init: 5.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + readout_groups: 16 + readout_scale: 0.5 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: seed999 + scalar_lr: 0.02 + seed: 999 + skip_gates_enabled: True + sliding_window_enabled: True + smear_gate_enabled: True + smear_gate_width: 12 + temp_cal_batches: 50 + temp_cal_enabled: 
False + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: /workspace/parameter-golf/data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_chunk_tokens: 32768 + ttt_enabled: True + ttt_epochs: 3 + ttt_lr: 0.005 + ttt_momentum: 0.9 + use_pass_readout: False + val_batch_tokens: 524288 + val_files: /workspace/parameter-golf/data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.72 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 128 +val_tokens: 40542208 +model_params:36486278 +gptq:reserving 12s, effective=588000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_step: 1/20 +loop_warmup_step: 2/20 +loop_warmup_step: 3/20 +loop_warmup_step: 4/20 +loop_warmup_step: 5/20 +loop_warmup_step: 6/20 +loop_warmup_step: 10/20 +loop_warmup_step: 20/20 +0/20000 val_loss: 9.0083 val_bpb: 3.4876 +1/20000 train_loss: 9.0065 train_time: 0.0m tok/s: 8085462 +2/20000 train_loss: 12.2839 train_time: 0.0m tok/s: 7995772 +3/20000 train_loss: 11.1313 train_time: 0.0m tok/s: 7876203 +4/20000 train_loss: 9.6083 train_time: 0.0m tok/s: 7830526 +5/20000 train_loss: 8.3264 train_time: 0.0m tok/s: 7809404 +500/20000 train_loss: 3.2918 train_time: 0.9m tok/s: 7492101 +1000/20000 train_loss: 3.2060 train_time: 1.8m tok/s: 7472803 +1500/20000 train_loss: 3.1122 train_time: 2.6m tok/s: 7464925 +layer_loop:enabled step:1952 frac:0.350 encoder:[0, 1, 2, 3, 4, 5, 3, 4] decoder:[5, 3, 4, 5, 6, 7, 8, 9, 10] +2000/20000 train_loss: 3.0707 train_time: 3.6m tok/s: 7374274 +2500/20000 train_loss: 2.9827 train_time: 4.9m tok/s: 6750658 +3000/20000 train_loss: 3.0066 train_time: 6.2m tok/s: 6390914 +3500/20000 train_loss: 2.9431 train_time: 7.5m tok/s: 6156601 +4000/20000 train_loss: 2.8129 train_time: 8.8m tok/s: 5984704 +4000/20000 val_loss: 2.8500 val_bpb: 1.1034 +4403/20000 val_loss: 2.8071 val_bpb: 1.0868 +stopping_early: wallclock_cap train_time: 588029ms step: 4403/20000 +peak memory allocated: 39506 MiB reserved: 39574 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.80404559 val_bpb:1.08561033 eval_time:6108ms +Serialized model: 136555547 bytes +Code size: 18097 bytes +GPTQ:collecting Hessians from calibration data... 
+GPTQ:collected 67 Hessians in 13.1s +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): bigram.scale, blocks.attn.q_gain, logit_temp, smear_lambda + pertensor int8 (control): blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights + simple int6 (bigram embed): bigram.embed.weight + simple int8 (small matrix): bigram.proj.weight, blocks.attn.attn_gate_proj.weight, smear_gate.weight +Serialized model quantized+brotli: 15978006 bytes +Total submission size quantized+brotli: 15996103 bytes +quantized val_loss:2.83249036 val_bpb:1.09662297 eval_time:24096ms +quantized_sliding_window val_loss:2.78950296 val_bpb:1.07998003 eval_time:120053ms +ttt:start chunks=1238 ttt_lr=0.005 ttt_epochs=3 +quantized_ttt val_loss:2.78608265 val_bpb:1.07865582 eval_time:333575ms