diff --git a/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/README.md b/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/README.md new file mode 100644 index 0000000000..69ec8a7a65 --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/README.md @@ -0,0 +1,193 @@ +# Record candidate: PR #1797 base + Token-only n-gram tilt + AsymLogit Rescale + #2060 levers + NUM_PHASES=1 — GATE_WINDOW=8 + +**Comparison baseline: PR #2130 (1.05670 BPB)** — direct stack baseline. +**Merged-leaderboard comparison: PR #1855 (1.06108 BPB).** + +This submission keeps the PR #2130 architecture/training stack identical and only changes: **GATE_WINDOW=8**. Every other hyperparameter, env var, and code path is byte-for-byte the PR #2130 reproduction command. + +## Full validation coverage + +The 3 per-seed logs evaluate the full CaseOps validation shard target set: + +| Seed | `val_tokens` | `target_tokens` | +|-----:|-------------:|----------------:| +| 314 | 47,851,520 | 47,851,520 | +| 42 | 47,851,520 | 47,851,520 | +| 0 | 47,851,520 | 47,851,520 | + +Validation tokens are 47,851,520 (matches the PR #1855 / PR #1797 / PR #2130 truncation; PR #2014's 47,853,343 figure is methodology, not legality). + +The tokenizer, CaseOps transform, training shards, validation shard, and byte sidecar format are the canonical HF-hosted CaseOps export (`romeerp/parameter-golf-caseops-v1`) used by the merged PR #1855 setup. + +## What changed vs PR #2130 + +| Lever | PR #2130 | This submission | Mechanism | +|-------|----------|-----------------|-----------| +| `GATE_WINDOW` | 12 | **8** | Tightens the SparseAttnGate window from 12 to 8 positions; the gate then attends to fewer (more selective) recent tokens. | + +Everything else — model architecture, optimizer, schedule, TTT, quantization, compression — is byte-for-byte identical to the PR #2130 stack. + +## Architecture and training stack + +| Component | Setting | +|-----------|---------| +| Model | 11 layers, 512d, 8 query heads, 4 KV heads, MLP 4x | +| Tokenizer/data | SP8192 CaseOps lossless caps, byte sidecar accounting (PR #1729 / #1736 lineage) | +| RoPE | Partial RoPE, 16 dims | +| Recurrence | Layers 3-5 looped at frac=0.35 | +| Parallel decoder | Layer 8+ | +| XSA | All 11 layers | +| Gates | BOS-fixed SmearGate, SparseAttnGate (`scale=0.5`) | +| Optimizer | Muon on matrix params (LR=0.028), Adam on embedding/scalars (BETA2=0.99) | +| EMA | EMA_DECAY=0.9965 | +| Quantization | GPTQ int6 matrices, int7 embeddings, LQER asymmetric rank-4 (GROUP=32, TOP_K=3) | +| GPTQ reserve | 0.5s | +| Compression | per-group | +| Eval context | `EVAL_SEQ_LEN=2560`, `TTT_EVAL_SEQ_LEN=2560` | +| TTT | Quantized phased LoRA TTT (RANK=80, LR=8e-5, BETA2=0.99, WEIGHT_DECAY=2.0), score-first, K-off, 1 phase, 2500-doc prefix | +| Logit softcap | AsymLogit Rescale (softcap_pos, softcap_neg, init 30.0, global TTT) | +| Tilt | Token-only n-gram tilt (TOKEN_ORDER=16, TOKEN_THRESHOLD=0.800, TOKEN_BOOST=2.625) | + +## Compliance notes + +- **Artifact cap:** all seeds <= 16,000,000 bytes. +- **Training wallclock:** training stops at the 600s wallclock budget (matching PR #2130). +- **Eval wallclock:** all validation passes are <= 600s. +- **Score-before-update:** `quantized_ttt_phased` scores each chunk before applying that chunk's LoRA update. Inherited unchanged from PR #2130. +- **Full validation targets:** `val_tokens == target_tokens == 47,851,520` in all included logs. +- **No validation data in training:** training uses only training shards. TTT accesses validation documents left-to-right under the score-first rule. +- **No external cache or direct memorization:** no SLOT, persistent n-gram cache, PPM mixture, logit bias table, or validation-derived precomputation. +- **Original-byte BPB:** CaseOps byte sidecar accounting is preserved. + +### Token-only n-gram causality (PR #1514 precedent) + +This base inherits the strictly-causal token n-gram tilt from PR #2130. Within-word and word-start channels are explicitly disabled (`WITHIN_TAU=99.0`, `WORD_TAU=99.0`, `WITHIN_BOOST=0.0`, `WORD_BOOST=0.0`, `AGREE_ADD_BOOST=0.0`), so only the strictly-causal token channel fires. The C1 causality property of `online_ngram_state.c` is preserved and is unaffected by the lever change(s) in this submission. + +### AsymLogit Rescale safety + +`softcap_pos` and `softcap_neg` are nn.Parameter members of the base GPT module, NOT of `BatchedTTTLoRA`. They are adapted by global TTT once over the prefix pass; per-doc TTT touches only the LoRA adapters. ~8 bytes total in the artifact. Unchanged from PR #2130. + +### Closed-form n-gram tilt preserves probability mass + +Per `online_ngram_tilt.py`, the tilt applies `logit += boost` for the gated token, then renormalizes via softmax. Sum of probabilities equals 1 by construction (C2 property required by PR #1514). Unchanged from PR #2130. + +### Eval wallclock — full disclosure + +Total eval times include the n-gram precompute, computed INSIDE the eval timer. `NGRAM_HINT_PRECOMPUTE_OUTSIDE=0` is set explicitly in the launch command. Logs show `ngram_hint_precompute_outside: False` and a `precompute_done` line. + +## Reproduction + +Install the dependencies in `requirements.txt`. FlashAttention 3 and the `lrzip` system binary are noted there because they require separate install paths. + +This submission uses the canonical CaseOps SP8192 export hosted on Hugging Face. Logs are produced from a 50,000-document validation split with 80 training shards. + +Preferred data setup (matches PR #2130): + +```bash +python3 - <<'PY' +from huggingface_hub import snapshot_download +snapshot_download( + repo_id="romeerp/parameter-golf-caseops-v1", + repo_type="dataset", + local_dir="./data/datasets/fineweb10B_sp8192_caseops", + allow_patterns=[ + "datasets/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/*", + "datasets/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model", + ], + max_workers=8, +) +PY +``` + +Training command (only the lever override differs from PR #2130): + +```bash +for SEED in 314 42 0; do + PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ + NCCL_NET=Socket \ + DATA_DIR=./data \ + DATA_PATH=./data/datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved \ + TOKENIZER_PATH=./data/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model \ + CASEOPS_ENABLED=1 \ + VOCAB_SIZE=8192 \ + ASYM_LOGIT_RESCALE=1 \ + NGRAM_TILT_ENABLED=1 \ + NGRAM_HINT_PRECOMPUTE_OUTSIDE=0 \ + TOKEN_ORDER=16 \ + TOKEN_THRESHOLD=0.800 \ + TOKEN_BOOST=2.625 \ + WITHIN_TAU=99.0 WITHIN_BOOST=0.0 \ + WORD_TAU=99.0 WORD_BOOST=0.0 \ + AGREE_ADD_BOOST=0.0 \ + SMEAR_GATE_BOS_FIX=0 \ + TTT_LORA_EMA_DECAY=0.0 \ + TTT_UPDATE_EVERY=1 \ + PHASED_TTT_PREFIX_DOCS=2500 \ + PHASED_TTT_NUM_PHASES=1 \ + EVAL_SEQ_LEN=2560 \ + TTT_EVAL_SEQ_LEN=2560 \ + COMPRESSOR=pergroup \ + MATRIX_CLIP_SIGMAS=12.85 \ + ATTN_CLIP_SIGMAS=13.0 \ + MLP_CLIP_SIGMAS=11.5 \ + EMBED_BITS=7 \ + EMBED_CLIP_SIGMAS=14.0 \ + MATRIX_LR=0.028 \ + MIN_LR=0.1 \ + WARMDOWN_FRAC=0.85 \ + BETA2=0.99 \ + FUSED_CE_ENABLED=1 \ + SPARSE_ATTN_GATE_ENABLED=1 \ + SPARSE_ATTN_GATE_SCALE=0.5 \ + SMEAR_GATE_ENABLED=1 \ + GATE_WINDOW=8 \ + LQER_ENABLED=1 \ + LQER_RANK=4 \ + LQER_TOP_K=3 \ + LQER_FACTOR_BITS=4 \ + LQER_ASYM_ENABLED=1 \ + LQER_ASYM_GROUP=32 \ + TTT_WARM_START_A=1 \ + TTT_LORA_RANK=80 \ + TTT_BETA2=0.99 \ + TTT_WEIGHT_DECAY=2.0 \ + TTT_LORA_LR=8e-5 \ + GPTQ_RESERVE_SECONDS=0.5 \ + GPTQ_CALIBRATION_BATCHES=16 \ + TTT_K_LORA=0 \ + TTT_O_LORA=1 \ + TTT_MLP_LORA=1 \ + EMA_DECAY=0.9965 \ + SEED=$SEED \ + torchrun --standalone --nproc_per_node=8 train_gpt.py \ + > train_seed${SEED}.log 2>&1 +done +``` + + +## Included files + +- `train_gpt.py` - full training/eval script (PR #2130's version). +- `train_seed*.log` - full per-seed logs. +- `submission.json` - structured metadata. +- `README.md` - this file. +- `requirements.txt` - Python dependencies plus notes for FA3 and `lrzip`. +- `prepare_caseops_data.py` - fallback CaseOps dataset/token/byte-sidecar preparation. +- `lossless_caps.py` - reversible CaseOps transform. +- `tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model` - SentencePiece tokenizer. +- `online_ngram_tilt.py` - n-gram tilt Python wrapper. +- `online_ngram_state.c` - n-gram state machine (token-only path used). + +## Lineage and credits + +This submission is a single-knob (resp. two-knob) refinement on top of PR #2130, which itself stacks on the public CaseOps/SP8192 record lineage: + +- PR #2130 by @codemath3000 - direct comparison baseline; token-only n-gram tilt + AsymLogit Rescale + #2060 levers + NUM_PHASES=1 stacked on PR #1797 / PR #1855 lineage. +- PR #1855 by @codemath3000 - merged leaderboard record. +- PR #1797 by @dexhunter - SmearGate and LQER asymmetric rank-4 lineage; direct base of PR #2130. +- PR #1514 (merged) - token-only n-gram tilt legality precedent. +- PR #1923 - AsymLogit Rescale technique. +- PR #2060 by @S0urC10ud - hyperparameter sweep values. +- PR #2014 by @simonbissonnette - PHASED_TTT_NUM_PHASES=1 precedent. + +The new contribution here is the targeted hyperparameter change (GATE_WINDOW=8) on the PR #2130 stack, isolated as a clean ablation so reviewers can compare the lever effect against PR #2130 directly. diff --git a/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/lossless_caps.py b/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/lossless_caps.py new file mode 100644 index 0000000000..98e472f824 --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/lossless_caps.py @@ -0,0 +1,833 @@ +"""Lossless capitalization pre-encoding helpers. + +This module provides a narrow, reversible transform that only touches +ASCII capital letters `A-Z`. Each uppercase ASCII letter is rewritten as +``, where `sentinel` is a private-use Unicode +character that is escaped by doubling if it appears literally in the +input text. + +Example with the default sentinel `\\uE000`: + + "The NASA Launch" -> "\\uE000the \\uE000n\\uE000a\\uE000s\\uE000a \\uE000launch" + +The transform is intentionally simple for v1: + +- lowercase ASCII letters are unchanged +- uppercase ASCII letters become sentinel + lowercase letter +- non-ASCII characters are left untouched +- literal sentinel characters are escaped as sentinel + sentinel + +This makes the transform exactly invertible while allowing a downstream +tokenizer to reuse lowercase subwords across case variants. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Callable, Iterable + +LOSSLESS_CAPS_V1 = "lossless_caps_v1" +LOSSLESS_CAPS_V2 = "lossless_caps_v2" +LOSSLESS_CAPS_V3 = "lossless_caps_v3" +LOSSLESS_CAPS_V4 = "lossless_caps_v4" +LOSSLESS_CAPS_V5 = "lossless_caps_v5" +LOSSLESS_CAPS_V6 = "lossless_caps_v6" +LOSSLESS_CAPS_V7 = "lossless_caps_v7" +LOSSLESS_CAPS_CASEOPS_V1 = "lossless_caps_caseops_v1" +IDENTITY = "identity" +DEFAULT_SENTINEL = "\uE000" +DEFAULT_V2_TITLE = "\uE001" +DEFAULT_V2_ALLCAPS = "\uE002" +DEFAULT_V2_CAPNEXT = "\uE003" +DEFAULT_V2_ESC = "\uE004" +DEFAULT_V5_TITLE_MIN_LEN = 7 +DEFAULT_V6_ALLCAPS_MIN_LEN = 3 +DEFAULT_V7_ALLCAPS_MIN_LEN = 4 + + +class LosslessCapsError(ValueError): + """Raised when a transformed string is malformed.""" + + +def _is_ascii_upper(ch: str) -> bool: + return "A" <= ch <= "Z" + + +def _is_ascii_lower(ch: str) -> bool: + return "a" <= ch <= "z" + + +def _is_ascii_alpha(ch: str) -> bool: + return _is_ascii_lower(ch) or _is_ascii_upper(ch) + + +def _validate_distinct_single_chars(*chars: str) -> None: + if any(len(ch) != 1 for ch in chars): + raise ValueError("all control characters must be exactly one character") + if len(set(chars)) != len(chars): + raise ValueError("control characters must be distinct") + + +def encode_lossless_caps_v1(text: str, *, sentinel: str = DEFAULT_SENTINEL) -> str: + """Encode ASCII capitals reversibly using a one-character sentinel.""" + if len(sentinel) != 1: + raise ValueError("sentinel must be exactly one character") + out: list[str] = [] + for ch in text: + if ch == sentinel: + out.append(sentinel) + out.append(sentinel) + elif _is_ascii_upper(ch): + out.append(sentinel) + out.append(ch.lower()) + else: + out.append(ch) + return "".join(out) + + +def decode_lossless_caps_v1(text: str, *, sentinel: str = DEFAULT_SENTINEL) -> str: + """Decode the `lossless_caps_v1` transform back to the original text.""" + if len(sentinel) != 1: + raise ValueError("sentinel must be exactly one character") + out: list[str] = [] + i = 0 + n = len(text) + while i < n: + ch = text[i] + if ch != sentinel: + out.append(ch) + i += 1 + continue + if i + 1 >= n: + raise LosslessCapsError("dangling capitalization sentinel at end of string") + nxt = text[i + 1] + if nxt == sentinel: + out.append(sentinel) + elif _is_ascii_lower(nxt): + out.append(nxt.upper()) + else: + raise LosslessCapsError( + f"invalid sentinel escape sequence {sentinel + nxt!r}; " + "expected doubled sentinel or sentinel + lowercase ASCII letter" + ) + i += 2 + return "".join(out) + + +def encode_lossless_caps_v2( + text: str, + *, + title: str = DEFAULT_V2_TITLE, + allcaps: str = DEFAULT_V2_ALLCAPS, + capnext: str = DEFAULT_V2_CAPNEXT, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Encode ASCII word capitalization with cheap word-level markers. + + Rules over maximal ASCII alphabetic runs: + - lowercase words stay unchanged + - TitleCase words become `title + lowercase(word)` + - ALLCAPS words become `allcaps + lowercase(word)` + - mixed-case words use: + - optional `title` when the first letter is uppercase + - `capnext + lowercase(letter)` for subsequent uppercase letters + - literal control characters are escaped as `esc + literal` + """ + _validate_distinct_single_chars(title, allcaps, capnext, esc) + controls = {title, allcaps, capnext, esc} + out: list[str] = [] + i = 0 + n = len(text) + while i < n: + ch = text[i] + if ch in controls: + out.append(esc) + out.append(ch) + i += 1 + continue + if not _is_ascii_alpha(ch): + out.append(ch) + i += 1 + continue + + j = i + 1 + while j < n and _is_ascii_alpha(text[j]): + j += 1 + word = text[i:j] + lower_word = word.lower() + + if word.islower(): + out.append(word) + elif len(word) >= 2 and word.isupper(): + out.append(allcaps) + out.append(lower_word) + elif _is_ascii_upper(word[0]) and word[1:].islower(): + out.append(title) + out.append(lower_word) + else: + if _is_ascii_upper(word[0]): + out.append(title) + out.append(lower_word[0]) + for orig_ch, lower_ch in zip(word[1:], lower_word[1:], strict=True): + if _is_ascii_upper(orig_ch): + out.append(capnext) + out.append(lower_ch) + i = j + return "".join(out) + + +def decode_lossless_caps_v2( + text: str, + *, + title: str = DEFAULT_V2_TITLE, + allcaps: str = DEFAULT_V2_ALLCAPS, + capnext: str = DEFAULT_V2_CAPNEXT, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Decode the `lossless_caps_v2` transform back to the original text.""" + _validate_distinct_single_chars(title, allcaps, capnext, esc) + out: list[str] = [] + pending_escape = False + pending_word_mode: str | None = None + active_allcaps = False + pending_capnext = False + in_ascii_word = False + + for ch in text: + if pending_escape: + if pending_word_mode is not None and not _is_ascii_alpha(ch): + raise LosslessCapsError("escaped control char cannot satisfy pending word capitalization mode") + out.append(ch) + pending_escape = False + if _is_ascii_alpha(ch): + in_ascii_word = True + else: + in_ascii_word = False + active_allcaps = False + continue + + if ch == esc: + pending_escape = True + continue + if ch == title: + if pending_word_mode is not None or in_ascii_word or pending_capnext: + raise LosslessCapsError("invalid title marker placement") + pending_word_mode = "title" + continue + if ch == allcaps: + if pending_word_mode is not None or in_ascii_word or pending_capnext: + raise LosslessCapsError("invalid allcaps marker placement") + pending_word_mode = "allcaps" + continue + if ch == capnext: + if pending_capnext: + raise LosslessCapsError("duplicate capnext marker") + pending_capnext = True + continue + + if _is_ascii_alpha(ch): + at_word_start = not in_ascii_word + if at_word_start: + if pending_word_mode == "allcaps": + out.append(ch.upper()) + active_allcaps = True + elif pending_word_mode == "title": + out.append(ch.upper()) + elif pending_capnext: + out.append(ch.upper()) + else: + out.append(ch) + pending_word_mode = None + pending_capnext = False + in_ascii_word = True + continue + + if pending_word_mode is not None: + raise LosslessCapsError("word capitalization marker leaked into the middle of a word") + if active_allcaps: + out.append(ch.upper()) + elif pending_capnext: + out.append(ch.upper()) + else: + out.append(ch) + pending_capnext = False + continue + + if pending_word_mode is not None or pending_capnext: + raise LosslessCapsError("capitalization marker not followed by an ASCII letter") + out.append(ch) + in_ascii_word = False + active_allcaps = False + + if pending_escape: + raise LosslessCapsError("dangling escape marker at end of string") + if pending_word_mode is not None or pending_capnext: + raise LosslessCapsError("dangling capitalization marker at end of string") + return "".join(out) + + +def encode_lossless_caps_v3( + text: str, + *, + title: str = DEFAULT_V2_TITLE, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Encode only common word-level capitalization patterns. + + Rules over maximal ASCII alphabetic runs: + - lowercase words stay unchanged + - TitleCase words become `title + lowercase(word)` + - ALLCAPS words become `allcaps + lowercase(word)` + - all other mixed-case words are left unchanged + - literal control characters are escaped as `esc + literal` + """ + _validate_distinct_single_chars(title, allcaps, esc) + controls = {title, allcaps, esc} + out: list[str] = [] + i = 0 + n = len(text) + while i < n: + ch = text[i] + if ch in controls: + out.append(esc) + out.append(ch) + i += 1 + continue + if not _is_ascii_alpha(ch): + out.append(ch) + i += 1 + continue + + j = i + 1 + while j < n and _is_ascii_alpha(text[j]): + j += 1 + word = text[i:j] + + if word.islower(): + out.append(word) + elif len(word) >= 2 and word.isupper(): + out.append(allcaps) + out.append(word.lower()) + elif _is_ascii_upper(word[0]) and word[1:].islower(): + out.append(title) + out.append(word.lower()) + else: + out.append(word) + i = j + return "".join(out) + + +def decode_lossless_caps_v3( + text: str, + *, + title: str = DEFAULT_V2_TITLE, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Decode the `lossless_caps_v3` transform back to the original text.""" + _validate_distinct_single_chars(title, allcaps, esc) + out: list[str] = [] + pending_escape = False + pending_word_mode: str | None = None + active_allcaps = False + in_ascii_word = False + + for ch in text: + if pending_escape: + if pending_word_mode is not None and not _is_ascii_alpha(ch): + raise LosslessCapsError("escaped control char cannot satisfy pending word capitalization mode") + out.append(ch) + pending_escape = False + if _is_ascii_alpha(ch): + in_ascii_word = True + else: + in_ascii_word = False + active_allcaps = False + continue + + if ch == esc: + pending_escape = True + continue + if ch == title: + if pending_word_mode is not None or in_ascii_word: + raise LosslessCapsError("invalid title marker placement") + pending_word_mode = "title" + continue + if ch == allcaps: + if pending_word_mode is not None or in_ascii_word: + raise LosslessCapsError("invalid allcaps marker placement") + pending_word_mode = "allcaps" + continue + + if _is_ascii_alpha(ch): + at_word_start = not in_ascii_word + if at_word_start: + if pending_word_mode == "allcaps": + out.append(ch.upper()) + active_allcaps = True + elif pending_word_mode == "title": + out.append(ch.upper()) + else: + out.append(ch) + pending_word_mode = None + in_ascii_word = True + continue + + if pending_word_mode is not None: + raise LosslessCapsError("word capitalization marker leaked into the middle of a word") + out.append(ch.upper() if active_allcaps else ch) + continue + + if pending_word_mode is not None: + raise LosslessCapsError("capitalization marker not followed by an ASCII letter") + out.append(ch) + in_ascii_word = False + active_allcaps = False + + if pending_escape: + raise LosslessCapsError("dangling escape marker at end of string") + if pending_word_mode is not None: + raise LosslessCapsError("dangling capitalization marker at end of string") + return "".join(out) + + +def encode_lossless_caps_v4( + text: str, + *, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Encode only ALLCAPS ASCII words, leaving all other case untouched.""" + _validate_distinct_single_chars(allcaps, esc) + controls = {allcaps, esc} + out: list[str] = [] + i = 0 + n = len(text) + while i < n: + ch = text[i] + if ch in controls: + out.append(esc) + out.append(ch) + i += 1 + continue + if not _is_ascii_alpha(ch): + out.append(ch) + i += 1 + continue + j = i + 1 + while j < n and _is_ascii_alpha(text[j]): + j += 1 + word = text[i:j] + if len(word) >= 2 and word.isupper(): + out.append(allcaps) + out.append(word.lower()) + else: + out.append(word) + i = j + return "".join(out) + + +def decode_lossless_caps_v4( + text: str, + *, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Decode the `lossless_caps_v4` transform back to the original text.""" + _validate_distinct_single_chars(allcaps, esc) + out: list[str] = [] + pending_escape = False + pending_allcaps = False + in_ascii_word = False + active_allcaps = False + + for ch in text: + if pending_escape: + if pending_allcaps and not _is_ascii_alpha(ch): + raise LosslessCapsError("escaped control char cannot satisfy pending allcaps mode") + out.append(ch) + pending_escape = False + if _is_ascii_alpha(ch): + in_ascii_word = True + else: + in_ascii_word = False + active_allcaps = False + continue + + if ch == esc: + pending_escape = True + continue + if ch == allcaps: + if pending_allcaps or in_ascii_word: + raise LosslessCapsError("invalid allcaps marker placement") + pending_allcaps = True + continue + + if _is_ascii_alpha(ch): + if not in_ascii_word: + active_allcaps = pending_allcaps + pending_allcaps = False + in_ascii_word = True + out.append(ch.upper() if active_allcaps else ch) + continue + + if pending_allcaps: + raise LosslessCapsError("allcaps marker not followed by an ASCII letter") + out.append(ch) + in_ascii_word = False + active_allcaps = False + + if pending_escape: + raise LosslessCapsError("dangling escape marker at end of string") + if pending_allcaps: + raise LosslessCapsError("dangling allcaps marker at end of string") + return "".join(out) + + +def encode_lossless_caps_v5( + text: str, + *, + title: str = DEFAULT_V2_TITLE, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, + title_min_len: int = DEFAULT_V5_TITLE_MIN_LEN, +) -> str: + """Encode ALLCAPS words and only sufficiently long TitleCase words.""" + _validate_distinct_single_chars(title, allcaps, esc) + controls = {title, allcaps, esc} + out: list[str] = [] + i = 0 + n = len(text) + while i < n: + ch = text[i] + if ch in controls: + out.append(esc) + out.append(ch) + i += 1 + continue + if not _is_ascii_alpha(ch): + out.append(ch) + i += 1 + continue + j = i + 1 + while j < n and _is_ascii_alpha(text[j]): + j += 1 + word = text[i:j] + if len(word) >= 2 and word.isupper(): + out.append(allcaps) + out.append(word.lower()) + elif len(word) >= title_min_len and _is_ascii_upper(word[0]) and word[1:].islower(): + out.append(title) + out.append(word.lower()) + else: + out.append(word) + i = j + return "".join(out) + + +def decode_lossless_caps_v5( + text: str, + *, + title: str = DEFAULT_V2_TITLE, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Decode the `lossless_caps_v5` transform back to the original text.""" + return decode_lossless_caps_v3(text, title=title, allcaps=allcaps, esc=esc) + + +def encode_lossless_caps_v6( + text: str, + *, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, + allcaps_min_len: int = DEFAULT_V6_ALLCAPS_MIN_LEN, +) -> str: + """Encode only ALLCAPS words with length >= allcaps_min_len.""" + _validate_distinct_single_chars(allcaps, esc) + controls = {allcaps, esc} + out: list[str] = [] + i = 0 + n = len(text) + while i < n: + ch = text[i] + if ch in controls: + out.append(esc) + out.append(ch) + i += 1 + continue + if not _is_ascii_alpha(ch): + out.append(ch) + i += 1 + continue + j = i + 1 + while j < n and _is_ascii_alpha(text[j]): + j += 1 + word = text[i:j] + if len(word) >= allcaps_min_len and word.isupper(): + out.append(allcaps) + out.append(word.lower()) + else: + out.append(word) + i = j + return "".join(out) + + +def decode_lossless_caps_v6( + text: str, + *, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Decode the `lossless_caps_v6` transform back to the original text.""" + return decode_lossless_caps_v4(text, allcaps=allcaps, esc=esc) + + +def encode_lossless_caps_v7( + text: str, + *, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, + allcaps_min_len: int = DEFAULT_V7_ALLCAPS_MIN_LEN, +) -> str: + """Encode only ALLCAPS words with length >= 4.""" + return encode_lossless_caps_v6( + text, + allcaps=allcaps, + esc=esc, + allcaps_min_len=allcaps_min_len, + ) + + +def decode_lossless_caps_v7( + text: str, + *, + allcaps: str = DEFAULT_V2_ALLCAPS, + esc: str = DEFAULT_V2_ESC, +) -> str: + """Decode the `lossless_caps_v7` transform back to the original text.""" + return decode_lossless_caps_v6(text, allcaps=allcaps, esc=esc) + + +def get_text_transform(name: str | None) -> Callable[[str], str]: + """Return the forward text transform for the given config name.""" + normalized = IDENTITY if name in {None, "", IDENTITY} else str(name) + if normalized == IDENTITY: + return lambda text: text + if normalized == LOSSLESS_CAPS_V1: + return encode_lossless_caps_v1 + if normalized == LOSSLESS_CAPS_V2: + return encode_lossless_caps_v2 + if normalized == LOSSLESS_CAPS_V3: + return encode_lossless_caps_v3 + if normalized == LOSSLESS_CAPS_V4: + return encode_lossless_caps_v4 + if normalized == LOSSLESS_CAPS_V5: + return encode_lossless_caps_v5 + if normalized == LOSSLESS_CAPS_V6: + return encode_lossless_caps_v6 + if normalized == LOSSLESS_CAPS_V7: + return encode_lossless_caps_v7 + if normalized == LOSSLESS_CAPS_CASEOPS_V1: + return encode_lossless_caps_v2 + raise ValueError(f"unsupported text_transform={name!r}") + + +def get_text_inverse_transform(name: str | None) -> Callable[[str], str]: + """Return the inverse transform for the given config name.""" + normalized = IDENTITY if name in {None, "", IDENTITY} else str(name) + if normalized == IDENTITY: + return lambda text: text + if normalized == LOSSLESS_CAPS_V1: + return decode_lossless_caps_v1 + if normalized == LOSSLESS_CAPS_V2: + return decode_lossless_caps_v2 + if normalized == LOSSLESS_CAPS_V3: + return decode_lossless_caps_v3 + if normalized == LOSSLESS_CAPS_V4: + return decode_lossless_caps_v4 + if normalized == LOSSLESS_CAPS_V5: + return decode_lossless_caps_v5 + if normalized == LOSSLESS_CAPS_V6: + return decode_lossless_caps_v6 + if normalized == LOSSLESS_CAPS_V7: + return decode_lossless_caps_v7 + if normalized == LOSSLESS_CAPS_CASEOPS_V1: + return decode_lossless_caps_v2 + raise ValueError(f"unsupported text_transform={name!r}") + + +def normalize_text_transform_name(name: str | None) -> str: + """Normalize empty/None transform names to the identity transform.""" + return IDENTITY if name in {None, "", IDENTITY} else str(name) + + +def get_text_transform_control_symbols(name: str | None) -> list[str]: + """Return reserved control symbols used by a transform, if any.""" + normalized = normalize_text_transform_name(name) + if normalized == IDENTITY: + return [] + if normalized == LOSSLESS_CAPS_V1: + return [DEFAULT_SENTINEL] + if normalized == LOSSLESS_CAPS_V2: + return [DEFAULT_V2_TITLE, DEFAULT_V2_ALLCAPS, DEFAULT_V2_CAPNEXT, DEFAULT_V2_ESC] + if normalized == LOSSLESS_CAPS_CASEOPS_V1: + return [DEFAULT_V2_TITLE, DEFAULT_V2_ALLCAPS, DEFAULT_V2_CAPNEXT, DEFAULT_V2_ESC] + if normalized in {LOSSLESS_CAPS_V3, LOSSLESS_CAPS_V5}: + return [DEFAULT_V2_TITLE, DEFAULT_V2_ALLCAPS, DEFAULT_V2_ESC] + if normalized in {LOSSLESS_CAPS_V4, LOSSLESS_CAPS_V6, LOSSLESS_CAPS_V7}: + return [DEFAULT_V2_ALLCAPS, DEFAULT_V2_ESC] + raise ValueError(f"unsupported text_transform={name!r}") + + +def infer_text_transform_from_manifest(tokenizer_path: str | Path) -> str: + """Best-effort lookup of a tokenizer's text transform from a local manifest.""" + tokenizer_path = Path(tokenizer_path).expanduser().resolve() + manifest_candidates = [ + tokenizer_path.parent.parent / "manifest.json", + tokenizer_path.parent / "manifest.json", + ] + for manifest_path in manifest_candidates: + if not manifest_path.is_file(): + continue + try: + payload = json.loads(manifest_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + continue + tokenizers = payload.get("tokenizers") + if not isinstance(tokenizers, list): + continue + for tokenizer_meta in tokenizers: + if not isinstance(tokenizer_meta, dict): + continue + model_path = tokenizer_meta.get("model_path") or tokenizer_meta.get("path") + if not model_path: + continue + candidate = (manifest_path.parent / str(model_path)).resolve() + if candidate == tokenizer_path: + return normalize_text_transform_name(tokenizer_meta.get("text_transform")) + return IDENTITY + + +def surface_piece_original_byte_counts( + surfaces: Iterable[str], + *, + text_transform_name: str | None = None, + sentinel: str = DEFAULT_SENTINEL, +) -> list[int]: + """Return exact original UTF-8 byte counts contributed by each surface piece. + + `surfaces` must be the exact decoded text fragments emitted by SentencePiece + in order, e.g. `piece.surface` from `encode_as_immutable_proto`. + """ + normalized = normalize_text_transform_name(text_transform_name) + if normalized == IDENTITY: + return [len(surface.encode("utf-8")) for surface in surfaces] + if normalized == LOSSLESS_CAPS_V1: + if len(sentinel) != 1: + raise ValueError("sentinel must be exactly one character") + sentinel_bytes = len(sentinel.encode("utf-8")) + pending_sentinel = False + counts: list[int] = [] + for surface in surfaces: + piece_bytes = 0 + for ch in surface: + if pending_sentinel: + if ch == sentinel: + piece_bytes += sentinel_bytes + elif _is_ascii_lower(ch): + piece_bytes += 1 + else: + raise LosslessCapsError( + f"invalid continuation {ch!r} after capitalization sentinel" + ) + pending_sentinel = False + continue + if ch == sentinel: + pending_sentinel = True + else: + piece_bytes += len(ch.encode("utf-8")) + counts.append(piece_bytes) + if pending_sentinel: + raise LosslessCapsError("dangling capitalization sentinel across piece boundary") + return counts + if normalized not in {LOSSLESS_CAPS_V2, LOSSLESS_CAPS_V3, LOSSLESS_CAPS_V4, LOSSLESS_CAPS_V5, LOSSLESS_CAPS_V6, LOSSLESS_CAPS_V7, LOSSLESS_CAPS_CASEOPS_V1}: + raise ValueError(f"unsupported text_transform={text_transform_name!r}") + + title = DEFAULT_V2_TITLE + allcaps = DEFAULT_V2_ALLCAPS + capnext = DEFAULT_V2_CAPNEXT + esc = DEFAULT_V2_ESC + if normalized in {LOSSLESS_CAPS_V2, LOSSLESS_CAPS_CASEOPS_V1}: + _validate_distinct_single_chars(title, allcaps, capnext, esc) + elif normalized in {LOSSLESS_CAPS_V4, LOSSLESS_CAPS_V6, LOSSLESS_CAPS_V7}: + _validate_distinct_single_chars(allcaps, esc) + else: + _validate_distinct_single_chars(title, allcaps, esc) + pending_escape = False + pending_word_mode: str | None = None + active_allcaps = False + pending_capnext = False + in_ascii_word = False + counts: list[int] = [] + for surface in surfaces: + piece_bytes = 0 + for ch in surface: + if pending_escape: + if pending_word_mode is not None and not _is_ascii_alpha(ch): + raise LosslessCapsError("escaped control char cannot satisfy pending word capitalization mode") + piece_bytes += len(ch.encode("utf-8")) + pending_escape = False + if _is_ascii_alpha(ch): + in_ascii_word = True + else: + in_ascii_word = False + active_allcaps = False + continue + if ch == esc: + pending_escape = True + continue + if normalized in {LOSSLESS_CAPS_V2, LOSSLESS_CAPS_V3, LOSSLESS_CAPS_V5, LOSSLESS_CAPS_CASEOPS_V1} and ch == title: + if pending_word_mode is not None or in_ascii_word or pending_capnext: + raise LosslessCapsError("invalid title marker placement") + pending_word_mode = "title" + continue + if ch == allcaps: + if pending_word_mode is not None or in_ascii_word or pending_capnext: + raise LosslessCapsError("invalid allcaps marker placement") + pending_word_mode = "allcaps" + continue + if normalized in {LOSSLESS_CAPS_V2, LOSSLESS_CAPS_CASEOPS_V1} and ch == capnext: + if pending_capnext: + raise LosslessCapsError("duplicate capnext marker") + pending_capnext = True + continue + + if _is_ascii_alpha(ch): + at_word_start = not in_ascii_word + if at_word_start: + piece_bytes += 1 + active_allcaps = pending_word_mode == "allcaps" + pending_word_mode = None + pending_capnext = False + in_ascii_word = True + continue + if pending_word_mode is not None: + raise LosslessCapsError("word capitalization marker leaked into the middle of a word") + piece_bytes += 1 + pending_capnext = False + continue + + if pending_word_mode is not None or pending_capnext: + raise LosslessCapsError("capitalization marker not followed by an ASCII letter") + piece_bytes += len(ch.encode("utf-8")) + in_ascii_word = False + active_allcaps = False + counts.append(piece_bytes) + if pending_escape: + raise LosslessCapsError("dangling escape marker across piece boundary") + if pending_word_mode is not None or pending_capnext: + raise LosslessCapsError("dangling capitalization marker across piece boundary") + return counts diff --git a/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/online_ngram_state.c b/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/online_ngram_state.c new file mode 100644 index 0000000000..f8472a6f05 --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/online_ngram_state.c @@ -0,0 +1,433 @@ +#include +#include +#include + +#define COEFF_COUNT 32 + +static const uint64_t ROLLING_COEFFS[COEFF_COUNT] = { + 36313ULL, 27191ULL, 51647ULL, 81929ULL, 131071ULL, 196613ULL, + 262147ULL, 393241ULL, 524309ULL, 655373ULL, 786433ULL, 917521ULL, + 1048583ULL, 1179653ULL, 1310729ULL, 1441801ULL, 1572869ULL, 1703941ULL, + 1835017ULL, 1966087ULL, 2097169ULL, 2228243ULL, 2359319ULL, 2490389ULL, + 2621471ULL, 2752549ULL, 2883617ULL, 3014687ULL, 3145757ULL, 3276833ULL, + 3407903ULL, 3538973ULL, +}; + +static const uint64_t PAIR_MIX = 1000003ULL; +static const uint64_t PREFIX_BASE = 1099511628211ULL; +static const uint64_t LEN_MIX = 0x9E3779B185EBCA87ULL; +static const uint64_t TABLE_MIX = 0x9e3779b97f4a7c15ULL; + +typedef struct { + uint64_t key; + uint32_t total; + uint32_t top_count; + uint16_t top_tok; + uint16_t _pad; +} CtxBucket; + +typedef struct { + uint64_t key; + uint32_t count; + uint32_t _pad; +} PairBucket; + +typedef struct { + int token_ctx_len; + int token_prefix_len; + int token_head; + uint16_t *token_ring; + + CtxBucket *token_ctx_tbl; + uint8_t *token_ctx_used; + size_t token_ctx_mask; + + PairBucket *token_pair_tbl; + uint8_t *token_pair_used; + size_t token_pair_mask; + + uint64_t within_hash; + uint32_t within_len; + + CtxBucket *within_ctx_tbl; + uint8_t *within_ctx_used; + size_t within_ctx_mask; + + PairBucket *within_pair_tbl; + uint8_t *within_pair_used; + size_t within_pair_mask; +} OnlineNgramState; + +static inline size_t mix_index(uint64_t key, size_t mask) { + return (size_t)((key * TABLE_MIX) & mask); +} + +static inline size_t find_ctx_slot( + CtxBucket *tbl, + uint8_t *used, + size_t mask, + uint64_t key, + int *found +) { + size_t idx = mix_index(key, mask); + for (size_t probe = 0; probe <= mask; ++probe) { + if (!used[idx]) { + *found = 0; + return idx; + } + if (tbl[idx].key == key) { + *found = 1; + return idx; + } + idx = (idx + 1U) & mask; + } + *found = -1; + return 0; +} + +static inline size_t find_pair_slot( + PairBucket *tbl, + uint8_t *used, + size_t mask, + uint64_t key, + int *found +) { + size_t idx = mix_index(key, mask); + for (size_t probe = 0; probe <= mask; ++probe) { + if (!used[idx]) { + *found = 0; + return idx; + } + if (tbl[idx].key == key) { + *found = 1; + return idx; + } + idx = (idx + 1U) & mask; + } + *found = -1; + return 0; +} + +static inline uint64_t token_pair_key(uint64_t ctx_key, uint16_t tok, int ctx_len) { + return (ctx_key * PAIR_MIX) ^ (((uint64_t)tok) * ROLLING_COEFFS[(size_t)ctx_len % COEFF_COUNT]); +} + +static inline uint64_t within_pair_key(uint64_t ctx_key, uint16_t tok) { + return (ctx_key * PAIR_MIX) ^ (((uint64_t)tok) * ROLLING_COEFFS[0]); +} + +static inline uint64_t extend_prefix_hash(uint64_t current_hash, uint16_t tok, uint32_t pos) { + return (current_hash * PREFIX_BASE) ^ (((uint64_t)tok + 1ULL) * ROLLING_COEFFS[(size_t)pos % COEFF_COUNT]); +} + +static inline uint32_t pair_increment( + PairBucket *tbl, + uint8_t *used, + size_t mask, + uint64_t key +) { + int found = 0; + size_t idx = find_pair_slot(tbl, used, mask, key, &found); + if (found < 0) { + return 0U; + } + if (!found) { + used[idx] = 1U; + tbl[idx].key = key; + tbl[idx].count = 1U; + return 1U; + } + tbl[idx].count += 1U; + return tbl[idx].count; +} + +static inline int ctx_increment( + CtxBucket *tbl, + uint8_t *used, + size_t mask, + uint64_t key, + uint16_t tok, + uint32_t pair_count +) { + int found = 0; + size_t idx = find_ctx_slot(tbl, used, mask, key, &found); + if (found < 0) { + return -1; + } + if (!found) { + used[idx] = 1U; + tbl[idx].key = key; + tbl[idx].total = 1U; + tbl[idx].top_count = pair_count; + tbl[idx].top_tok = tok; + return 0; + } + tbl[idx].total += 1U; + if (pair_count > tbl[idx].top_count) { + tbl[idx].top_count = pair_count; + tbl[idx].top_tok = tok; + } + return 0; +} + +static inline uint64_t token_context_hash(const OnlineNgramState *st) { + uint64_t h = 0ULL; + if (st->token_ctx_len <= 0) { + return h; + } + for (int j = 0; j < st->token_ctx_len; ++j) { + const int ring_idx = (st->token_head + j) % st->token_ctx_len; + h ^= ((uint64_t)st->token_ring[ring_idx]) * ROLLING_COEFFS[(size_t)j]; + } + return h; +} + +static inline void token_push(OnlineNgramState *st, uint16_t tok) { + if (st->token_ctx_len <= 0) { + return; + } + if (st->token_prefix_len < st->token_ctx_len) { + st->token_ring[st->token_prefix_len] = tok; + st->token_prefix_len += 1; + return; + } + st->token_ring[st->token_head] = tok; + st->token_head = (st->token_head + 1) % st->token_ctx_len; +} + +static void *xcalloc(size_t count, size_t size) { + if (count == 0 || size == 0) { + return NULL; + } + return calloc(count, size); +} + +static int alloc_tables( + size_t table_bits, + CtxBucket **ctx_tbl, + uint8_t **ctx_used, + size_t *ctx_mask, + PairBucket **pair_tbl, + uint8_t **pair_used, + size_t *pair_mask +) { + const size_t size = 1ULL << table_bits; + *ctx_tbl = (CtxBucket *)xcalloc(size, sizeof(CtxBucket)); + *ctx_used = (uint8_t *)xcalloc(size, sizeof(uint8_t)); + *pair_tbl = (PairBucket *)xcalloc(size, sizeof(PairBucket)); + *pair_used = (uint8_t *)xcalloc(size, sizeof(uint8_t)); + if (!*ctx_tbl || !*ctx_used || !*pair_tbl || !*pair_used) { + return -1; + } + *ctx_mask = size - 1U; + *pair_mask = size - 1U; + return 0; +} + +void *online_ngram_state_create( + int token_ctx_len, + int token_table_bits, + int within_table_bits +) { + if (token_ctx_len < 0 || token_table_bits <= 0 || within_table_bits <= 0) { + return NULL; + } + OnlineNgramState *st = (OnlineNgramState *)calloc(1, sizeof(OnlineNgramState)); + if (!st) { + return NULL; + } + st->token_ctx_len = token_ctx_len; + if (token_ctx_len > 0) { + st->token_ring = (uint16_t *)xcalloc((size_t)token_ctx_len, sizeof(uint16_t)); + if (!st->token_ring) { + free(st); + return NULL; + } + } + if (alloc_tables( + (size_t)token_table_bits, + &st->token_ctx_tbl, + &st->token_ctx_used, + &st->token_ctx_mask, + &st->token_pair_tbl, + &st->token_pair_used, + &st->token_pair_mask + ) != 0) { + free(st->token_ring); + free(st); + return NULL; + } + if (alloc_tables( + (size_t)within_table_bits, + &st->within_ctx_tbl, + &st->within_ctx_used, + &st->within_ctx_mask, + &st->within_pair_tbl, + &st->within_pair_used, + &st->within_pair_mask + ) != 0) { + free(st->token_pair_used); + free(st->token_pair_tbl); + free(st->token_ctx_used); + free(st->token_ctx_tbl); + free(st->token_ring); + free(st); + return NULL; + } + return (void *)st; +} + +void online_ngram_state_destroy(void *ptr) { + OnlineNgramState *st = (OnlineNgramState *)ptr; + if (!st) { + return; + } + free(st->within_pair_used); + free(st->within_pair_tbl); + free(st->within_ctx_used); + free(st->within_ctx_tbl); + free(st->token_pair_used); + free(st->token_pair_tbl); + free(st->token_ctx_used); + free(st->token_ctx_tbl); + free(st->token_ring); + free(st); +} + +void online_ngram_state_seed_prefix_token(void *ptr, uint16_t tok) { + OnlineNgramState *st = (OnlineNgramState *)ptr; + if (!st) { + return; + } + token_push(st, tok); +} + +int online_ngram_state_process_chunk( + void *ptr, + const uint16_t *tokens, + int64_t n_tokens, + const uint8_t *starts_new_word_lut, + const uint8_t *boundary_lut, + uint16_t *token_top_token, + float *token_top_prob, + uint16_t *within_top_token, + float *within_top_prob, + uint8_t *within_valid +) { + OnlineNgramState *st = (OnlineNgramState *)ptr; + if (!st || !tokens || n_tokens < 0) { + return -1; + } + for (int64_t i = 0; i < n_tokens; ++i) { + const uint16_t tok = tokens[i]; + const uint8_t is_boundary = boundary_lut[tok]; + const uint8_t is_new_word = starts_new_word_lut[tok]; + + uint64_t token_ctx_key = 0ULL; + if (st->token_ctx_len == 0 || st->token_prefix_len >= st->token_ctx_len) { + token_ctx_key = token_context_hash(st); + int found = 0; + size_t idx = find_ctx_slot( + st->token_ctx_tbl, + st->token_ctx_used, + st->token_ctx_mask, + token_ctx_key, + &found + ); + if (found > 0) { + token_top_token[i] = st->token_ctx_tbl[idx].top_tok; + token_top_prob[i] = + (float)st->token_ctx_tbl[idx].top_count / (float)st->token_ctx_tbl[idx].total; + } else { + token_top_token[i] = 0U; + token_top_prob[i] = 0.0f; + } + } else { + token_top_token[i] = 0U; + token_top_prob[i] = 0.0f; + } + + uint64_t within_ctx_key = 0ULL; + if (!is_boundary && !is_new_word && st->within_len > 0U) { + within_ctx_key = st->within_hash ^ ((uint64_t)st->within_len * LEN_MIX); + int found = 0; + size_t idx = find_ctx_slot( + st->within_ctx_tbl, + st->within_ctx_used, + st->within_ctx_mask, + within_ctx_key, + &found + ); + within_valid[i] = 1U; + if (found > 0) { + within_top_token[i] = st->within_ctx_tbl[idx].top_tok; + within_top_prob[i] = + (float)st->within_ctx_tbl[idx].top_count / (float)st->within_ctx_tbl[idx].total; + } else { + within_top_token[i] = 0U; + within_top_prob[i] = 0.0f; + } + } else { + within_valid[i] = 0U; + within_top_token[i] = 0U; + within_top_prob[i] = 0.0f; + } + + if (st->token_ctx_len == 0 || st->token_prefix_len >= st->token_ctx_len) { + const uint64_t pair_key = token_pair_key(token_ctx_key, tok, st->token_ctx_len); + const uint32_t pair_count = pair_increment( + st->token_pair_tbl, + st->token_pair_used, + st->token_pair_mask, + pair_key + ); + if (pair_count == 0U) { + return -2; + } + if (ctx_increment( + st->token_ctx_tbl, + st->token_ctx_used, + st->token_ctx_mask, + token_ctx_key, + tok, + pair_count + ) != 0) { + return -3; + } + } + token_push(st, tok); + + if (is_boundary) { + st->within_hash = 0ULL; + st->within_len = 0U; + continue; + } + if (is_new_word || st->within_len == 0U) { + st->within_hash = extend_prefix_hash(0ULL, tok, 0U); + st->within_len = 1U; + continue; + } + const uint32_t within_pair_count = pair_increment( + st->within_pair_tbl, + st->within_pair_used, + st->within_pair_mask, + within_pair_key(within_ctx_key, tok) + ); + if (within_pair_count == 0U) { + return -4; + } + if (ctx_increment( + st->within_ctx_tbl, + st->within_ctx_used, + st->within_ctx_mask, + within_ctx_key, + tok, + within_pair_count + ) != 0) { + return -5; + } + st->within_hash = extend_prefix_hash(st->within_hash, tok, st->within_len); + st->within_len += 1U; + } + return 0; +} diff --git a/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/online_ngram_tilt.py b/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/online_ngram_tilt.py new file mode 100644 index 0000000000..973c21866f --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/online_ngram_tilt.py @@ -0,0 +1,386 @@ +""" +Vendored online n-gram tilt helpers from PR #1145 (AnirudhRahul, valerio-endorsed). + +Provides causal, normalized, prefix-only n-gram experts that propose at most one +hinted token per scored position. Caller obtains q_t = p(h_t | x) from the model +(post-TTT-adapt logits) and applies multiplicative-boost-with-renorm: + + p'(a) = exp(beta * 1[a == h_t]) * p(a) / Z_t + Z_t = 1 - q_t + exp(beta) * q_t = 1 + q_t * (exp(beta) - 1) + -log p'(y_realized) = -log p(y) - beta * 1[y == h_t] + log Z_t + = ptl - beta * is_hit + log1p(q_t * (exp(beta) - 1)) + +Compliance: +- C1 causal: hint h_t computed from strict prefix (tokens 0..t-1 only) +- C2 normalized over Sigma: closed-form Z_t over full vocab softmax +- C3 score-before-update: hints precomputed in single L->R pass; loss uses prefix-only +- C4 single pass: process_chunk advances state monotonically + +Compatible with both #1934/#1855 base architectures via Hyperparameter env-var gates. +""" + +from __future__ import annotations + +import ctypes +import math +import os +import subprocess +from collections import deque +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch + + +SCRIPT_DIR = Path(__file__).resolve().parent +ONLINE_NGRAM_SRC = SCRIPT_DIR / "online_ngram_state.c" +ONLINE_NGRAM_LIB = SCRIPT_DIR / "libonline_ngram_state.so" + +WHITESPACE_BYTE_IDS = {9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 36} +EDGE_PUNCT = ".,:;!?()[]{}<>\"'`" + + +def normalize_word(text: str, mode: str) -> str: + text = text.strip() + if mode == "lower": + return text.lower() + if mode == "identity": + return text + if mode == "strip_punct_lower": + return text.strip(EDGE_PUNCT).lower() + raise ValueError(f"Unknown word normalization mode: {mode}") + + +def suggest_table_bits(expected_entries: int, load_factor: float) -> int: + if expected_entries <= 0: + return 16 + target = max(int(expected_entries / max(load_factor, 1e-6)), 1) + bits = max(int(math.ceil(math.log2(target))), 12) + return min(bits, 28) + + +def ensure_online_ngram_lib(log0=print) -> ctypes.CDLL: + needs_build = (not ONLINE_NGRAM_LIB.exists()) or ( + ONLINE_NGRAM_SRC.stat().st_mtime_ns > ONLINE_NGRAM_LIB.stat().st_mtime_ns + ) + if needs_build: + log0(f"ngram_tilt:building_native_helper src={ONLINE_NGRAM_SRC.name}") + subprocess.run( + [ + "gcc", "-O3", "-march=native", "-shared", "-fPIC", + "-o", str(ONLINE_NGRAM_LIB), + str(ONLINE_NGRAM_SRC), + ], + check=True, + ) + lib = ctypes.CDLL(str(ONLINE_NGRAM_LIB)) + lib.online_ngram_state_create.restype = ctypes.c_void_p + lib.online_ngram_state_create.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int] + lib.online_ngram_state_destroy.restype = None + lib.online_ngram_state_destroy.argtypes = [ctypes.c_void_p] + lib.online_ngram_state_seed_prefix_token.restype = None + lib.online_ngram_state_seed_prefix_token.argtypes = [ctypes.c_void_p, ctypes.c_uint16] + lib.online_ngram_state_process_chunk.restype = ctypes.c_int + lib.online_ngram_state_process_chunk.argtypes = [ + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_uint16), + ctypes.c_int64, + ctypes.POINTER(ctypes.c_uint8), + ctypes.POINTER(ctypes.c_uint8), + ctypes.POINTER(ctypes.c_uint16), + ctypes.POINTER(ctypes.c_float), + ctypes.POINTER(ctypes.c_uint16), + ctypes.POINTER(ctypes.c_float), + ctypes.POINTER(ctypes.c_uint8), + ] + return lib + + +class OnlineNgramState: + def __init__( + self, *, lib, token_ctx_len, token_table_bits, within_table_bits, + starts_new_word_lut, boundary_lut, seed_prefix_token, + ): + self.lib = lib + self.state = lib.online_ngram_state_create(token_ctx_len, token_table_bits, within_table_bits) + if not self.state: + raise RuntimeError( + f"Native ngram state alloc failed token_table_bits={token_table_bits} within_table_bits={within_table_bits}" + ) + self.starts_new_word_lut = np.ascontiguousarray(starts_new_word_lut.astype(np.uint8, copy=False)) + self.boundary_lut = np.ascontiguousarray(boundary_lut.astype(np.uint8, copy=False)) + self.lib.online_ngram_state_seed_prefix_token(self.state, ctypes.c_uint16(int(seed_prefix_token))) + + def close(self): + if self.state: + self.lib.online_ngram_state_destroy(self.state) + self.state = None + + def __del__(self): + self.close() + + def process_chunk(self, chunk_tokens): + chunk_tokens = np.ascontiguousarray(chunk_tokens.astype(np.uint16, copy=False)) + n = int(chunk_tokens.size) + token_top_token = np.zeros(n, dtype=np.uint16) + token_top_prob = np.zeros(n, dtype=np.float32) + within_top_token = np.zeros(n, dtype=np.uint16) + within_top_prob = np.zeros(n, dtype=np.float32) + within_valid = np.zeros(n, dtype=np.uint8) + rc = self.lib.online_ngram_state_process_chunk( + self.state, + chunk_tokens.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), + ctypes.c_int64(n), + self.starts_new_word_lut.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)), + self.boundary_lut.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)), + token_top_token.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), + token_top_prob.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), + within_top_token.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), + within_top_prob.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), + within_valid.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)), + ) + if rc != 0: + raise RuntimeError(f"Native ngram process_chunk failed rc={rc}") + return token_top_token, token_top_prob, within_top_token, within_top_prob, within_valid.astype(bool) + + +class WordStartState: + def __init__(self, *, sp, order, normalize_mode): + self.sp = sp + self.ctx_w = max(order - 1, 0) + self.normalize_mode = normalize_mode + self.prev_word_ids: deque = deque(maxlen=self.ctx_w) + self.current_word_tokens: list = [] + self.word_to_id: dict = {} + self.next_word_id = 1 + self.ctx_total: dict = {} + self.pair_count: dict = {} + self.ctx_best_token: dict = {} + self.ctx_best_count: dict = {} + + def _flush_current_word(self): + if not self.current_word_tokens: + return + text = normalize_word(self.sp.decode(self.current_word_tokens), self.normalize_mode) + if text: + wid = self.word_to_id.get(text) + if wid is None: + wid = self.next_word_id + self.word_to_id[text] = wid + self.next_word_id += 1 + if self.ctx_w > 0: + self.prev_word_ids.append(wid) + self.current_word_tokens = [] + + def process_chunk(self, chunk_tokens, *, starts_new_word_lut, boundary_lut): + chunk_tokens = np.ascontiguousarray(chunk_tokens.astype(np.uint16, copy=False)) + top_token = np.zeros(chunk_tokens.size, dtype=np.uint16) + top_prob = np.zeros(chunk_tokens.size, dtype=np.float32) + for i, tok_u16 in enumerate(chunk_tokens): + tok = int(tok_u16) + is_boundary = bool(boundary_lut[tok]) + is_word_start = bool(starts_new_word_lut[tok]) or not self.current_word_tokens + if is_boundary: + self._flush_current_word() + continue + if bool(starts_new_word_lut[tok]): + self._flush_current_word() + ctx_key = None + if is_word_start and len(self.prev_word_ids) >= self.ctx_w: + ctx_key = tuple(self.prev_word_ids) if self.ctx_w > 0 else () + total = self.ctx_total.get(ctx_key, 0) + if total > 0: + top_token[i] = np.uint16(self.ctx_best_token[ctx_key]) + top_prob[i] = np.float32(self.ctx_best_count[ctx_key] / total) + if is_word_start: + if ctx_key is not None: + pair_key = (ctx_key, tok) + pair = self.pair_count.get(pair_key, 0) + 1 + self.pair_count[pair_key] = pair + total = self.ctx_total.get(ctx_key, 0) + 1 + self.ctx_total[ctx_key] = total + best_count = self.ctx_best_count.get(ctx_key, 0) + if pair > best_count: + self.ctx_best_count[ctx_key] = pair + self.ctx_best_token[ctx_key] = tok + self.current_word_tokens = [tok] + else: + self.current_word_tokens.append(tok) + return top_token, top_prob + + +def build_piece_luts(*, tokenizer_path, vocab_size): + sp = spm.SentencePieceProcessor(model_file=tokenizer_path) + pieces = [sp.id_to_piece(i) for i in range(sp.vocab_size())] + starts_new_word_lut = np.zeros(vocab_size, dtype=np.uint8) + for i, piece in enumerate(pieces): + starts_new_word_lut[i] = 1 if piece.startswith("▁") else 0 + boundary_lut = np.zeros(vocab_size, dtype=np.uint8) + bos_id = sp.bos_id() + if bos_id >= 0 and bos_id < vocab_size: + boundary_lut[bos_id] = 1 + for tok in range(min(sp.vocab_size(), vocab_size)): + if sp.is_byte(tok) and tok in WHITESPACE_BYTE_IDS: + boundary_lut[tok] = 1 + return sp, starts_new_word_lut, boundary_lut + + +def build_hints_for_targets( + *, target_token_ids_np, tokenizer_path, vocab_size, log0=print, + token_order=16, token_threshold=0.800, token_boost=2.625, + within_tau=0.450, within_boost=0.750, + word_order=4, word_normalize="strip_punct_lower", + word_tau=0.650, word_boost=0.750, + agree_add_boost=0.500, +): + """Single L->R pass. Returns dict with hint_ids, gate_mask, boost_per_pos. + + target_token_ids_np: np.uint16 array of realized targets (length = total_targets). + Output arrays are aligned to target_token_ids_np indexing. + + For each scored position t we pick at most one hint h_t: + - prefer the expert with highest expected gain = p_top * boost - log1p(p_top * (exp(boost)-1)) + - if multiple experts agree on the same h_t, additive boost agree_add_boost + - gate (don't tilt) when no expert clears its threshold + + The realized loss formula used by the caller: + ptl' = ptl - beta * 1[y == h_t] + log1p(q_t * (exp(beta) - 1)) when gate_mask == True + ptl' = ptl when gate_mask == False + """ + sp, starts_new_word_lut, boundary_lut = build_piece_luts( + tokenizer_path=tokenizer_path, vocab_size=vocab_size + ) + total = int(target_token_ids_np.size) + if total == 0: + return { + "hint_ids": np.zeros(0, dtype=np.int64), + "gate_mask": np.zeros(0, dtype=bool), + "boost": np.zeros(0, dtype=np.float32), + "sp": sp, + "starts_new_word_lut": starts_new_word_lut, + "boundary_lut": boundary_lut, + } + + token_table_bits = suggest_table_bits(total, load_factor=0.55) + within_table_bits = suggest_table_bits(max(total // 2, 1), load_factor=0.60) + online_lib = ensure_online_ngram_lib(log0) + ngram_state = OnlineNgramState( + lib=online_lib, + token_ctx_len=max(token_order - 1, 0), + token_table_bits=token_table_bits, + within_table_bits=within_table_bits, + starts_new_word_lut=starts_new_word_lut, + boundary_lut=boundary_lut, + seed_prefix_token=int(target_token_ids_np[0]), + ) + word_state = WordStartState(sp=sp, order=word_order, normalize_mode=word_normalize) + + token_top_tok, token_top_prob, within_top_tok, within_top_prob, within_valid = ( + ngram_state.process_chunk(target_token_ids_np) + ) + word_top_tok, word_top_prob = word_state.process_chunk( + target_token_ids_np, + starts_new_word_lut=starts_new_word_lut, + boundary_lut=boundary_lut, + ) + + def _expected_gain(p_top, boost): + # E[ -log p'(y) under -log p(y)] when y ~ p + # = p_top * boost - log1p(p_top * (exp(boost) - 1)) + # Maximizing this over experts => pick the most informative hint. + log_norm = np.log1p(p_top * (math.exp(boost) - 1.0)) + return p_top * boost - log_norm + + token_gate = token_top_prob >= np.float32(token_threshold) + within_gate = within_valid & (within_top_prob >= np.float32(within_tau)) + word_gate = word_top_prob >= np.float32(word_tau) + + token_gain = np.where(token_gate, _expected_gain(token_top_prob.astype(np.float64), token_boost), -np.inf) + within_gain = np.where(within_gate, _expected_gain(within_top_prob.astype(np.float64), within_boost), -np.inf) + word_gain = np.where(word_gate, _expected_gain(word_top_prob.astype(np.float64), word_boost), -np.inf) + + stack = np.stack([token_gain, within_gain, word_gain], axis=1) + best_idx = np.argmax(stack, axis=1) + best_gain = np.max(stack, axis=1) + any_gate = best_gain > -np.inf + + hint_ids = np.zeros(total, dtype=np.int64) + boost = np.zeros(total, dtype=np.float32) + base_boost_per_expert = np.array([token_boost, within_boost, word_boost], dtype=np.float32) + hint_per_expert = np.stack([ + token_top_tok.astype(np.int64), + within_top_tok.astype(np.int64), + word_top_tok.astype(np.int64), + ], axis=1) + + rows = np.arange(total) + hint_ids[any_gate] = hint_per_expert[rows[any_gate], best_idx[any_gate]] + boost[any_gate] = base_boost_per_expert[best_idx[any_gate]] + + # Agreement bonus: if 2+ experts agree on the same hint as best, add agree_add_boost + gate_mask_each = np.stack([token_gate, within_gate, word_gate], axis=1) + expert_hints = hint_per_expert.copy() + expert_hints[~gate_mask_each] = -1 + agreements = (expert_hints == hint_ids[:, None]).sum(axis=1) + agreement_extra = np.where(agreements >= 2, np.float32(agree_add_boost), np.float32(0.0)) + boost = (boost + agreement_extra).astype(np.float32) + + log0( + f"ngram_tilt:hints total={total} gated={int(any_gate.sum())} " + f"token_gate={int(token_gate.sum())} within_gate={int(within_gate.sum())} word_gate={int(word_gate.sum())} " + f"agree2plus={int((agreements >= 2).sum())}" + ) + + return { + "hint_ids": hint_ids, + "gate_mask": any_gate, + "boost": boost, + "sp": sp, + "starts_new_word_lut": starts_new_word_lut, + "boundary_lut": boundary_lut, + } + + +def apply_tilt_to_ptl_torch( + ptl: torch.Tensor, + log_q_hint: torch.Tensor, + target_ids: torch.Tensor, + hint_ids: torch.Tensor, + gate_mask: torch.Tensor, + boost: torch.Tensor, +): + """Closed-form tilt applied to per-token NLL. + + All tensors same shape [..., L]. + ptl_tilted = ptl - beta * 1[y == h] + log1p(q * (exp(beta) - 1)) if gate else ptl + """ + boost64 = boost.to(torch.float64) + q = log_q_hint.to(torch.float64).clamp_(max=0.0).exp() + is_hit = (target_ids == hint_ids).to(torch.float64) + log_Z = torch.log1p(q * (torch.expm1(boost64))) + ptl_tilted = ptl.to(torch.float64) - boost64 * is_hit + log_Z + return torch.where(gate_mask, ptl_tilted, ptl.to(torch.float64)).to(ptl.dtype) + + +def apply_tilt_to_ptl_torch_fast( + ptl: torch.Tensor, + log_q_hint: torch.Tensor, + target_ids: torch.Tensor, + hint_ids: torch.Tensor, + gate_mask: torch.Tensor, + boost: torch.Tensor, +): + """fp32 variant of apply_tilt — cast removed where safe. + + BPB downstream accumulator is fp64, so per-token tilt computation in + fp32 has no impact on final precision. Saves ~10-15s per eval pass on + H100 (avoids fp64 ALU + double memory traffic). + """ + boost32 = boost.to(torch.float32) + q = log_q_hint.to(torch.float32).clamp_(max=0.0).exp() + is_hit = (target_ids == hint_ids).to(torch.float32) + log_Z = torch.log1p(q * (torch.expm1(boost32))) + ptl_f32 = ptl.to(torch.float32) + ptl_tilted = ptl_f32 - boost32 * is_hit + log_Z + return torch.where(gate_mask, ptl_tilted, ptl_f32).to(ptl.dtype) diff --git a/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/prepare_caseops_data.py b/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/prepare_caseops_data.py new file mode 100644 index 0000000000..e3a6123b68 --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/prepare_caseops_data.py @@ -0,0 +1,194 @@ +"""Prepare CaseOps-tokenized FineWeb shards + per-token byte sidecar. + +CaseOps (``lossless_caps_caseops_v1``) is a bijective, character-level text +transform that introduces four operator tokens in place of explicit +capitalization: TITLE, ALLCAPS, CAPNEXT, ESC. The transform is fully +reversible — no information is lost relative to the untransformed UTF-8 +text, so BPB stays computable on TRUE byte counts. + +Forward pipeline: + 1. Read the canonical FineWeb-10B doc stream (``docs_selected.jsonl`` + produced by ``data/download_hf_docs_and_tokenize.py`` in the root repo). + 2. Apply ``encode_lossless_caps_v2`` (the caseops_v1 alias) to each doc. + 3. Tokenize with the shipped SP model + ``tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model`` + (reserves TITLE/ALLCAPS/CAPNEXT/ESC + sentinel as user_defined_symbols). + 4. Write uint16 train/val shards (``fineweb_{train,val}_XXXXXX.bin``). + 5. For the VAL stream only, emit per-token byte sidecar shards + (``fineweb_val_bytes_XXXXXX.bin``, uint16 parallel arrays) that record + each token's ORIGINAL pre-transform UTF-8 byte count. BPB is computed + from these canonical bytes so the score is on the untransformed text + (not the transformed representation). + +Output layout — matches what ``train_gpt.py`` expects under +``DATA_DIR=./data`` with ``CASEOPS_ENABLED=1``: + + data/datasets/fineweb10B_sp8192_caseops/datasets/ + tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + datasets/fineweb10B_sp8192_lossless_caps_caseops_v1_reserved/ + fineweb_train_000000.bin + fineweb_train_000001.bin + ... + fineweb_val_000000.bin + fineweb_val_bytes_000000.bin + +Usage: + + python3 prepare_caseops_data.py \\ + --docs ./fineweb10B_raw/docs_selected.jsonl \\ + --out ./data/datasets/fineweb10B_sp8192_caseops/datasets \\ + --sp ./tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model + +Requirements: sentencepiece, numpy. CPU-only. Runs once; reused across seeds. +""" +from __future__ import annotations + +import argparse +import json +import pathlib +import struct +import sys + +import numpy as np +import sentencepiece as spm + +# Local import — lossless_caps.py ships next to this script. +sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent)) +from lossless_caps import ( # noqa: E402 + LOSSLESS_CAPS_CASEOPS_V1, + encode_lossless_caps_v2, + surface_piece_original_byte_counts, +) + + +SHARD_MAGIC = 20240520 +SHARD_VERSION = 1 +SHARD_TOKENS = 10_000_000 # tokens per shard — matches the main pipeline +BOS_ID = 1 # SP model's control token; train_gpt.py:_find_docs requires BOS per doc + + +def _write_shard(out_path: pathlib.Path, arr: np.ndarray) -> None: + """Write a uint16 shard in the standard header-prefixed format.""" + assert arr.dtype == np.uint16 + header = np.zeros(256, dtype=np.int32) + header[0] = SHARD_MAGIC + header[1] = SHARD_VERSION + header[2] = int(arr.size) + with out_path.open("wb") as fh: + fh.write(header.tobytes()) + fh.write(arr.tobytes()) + + +def _iter_docs(docs_path: pathlib.Path, skip: int = 0): + """Yield doc strings from a jsonl file (one json object per line). + + If skip > 0, the first ``skip`` non-empty lines are discarded before + yielding begins. Use this to continue a prep run from a known offset + (e.g. after already writing N shards from the first M docs). + """ + with docs_path.open("r", encoding="utf-8") as fh: + skipped = 0 + for line in fh: + line = line.strip() + if not line: + continue + if skipped < skip: + skipped += 1 + continue + obj = json.loads(line) + # Support both {"text": ...} and raw strings. + yield obj["text"] if isinstance(obj, dict) else obj + + +def _token_original_byte_counts( + sp: spm.SentencePieceProcessor, + original_text: str, + transformed_text: str, +) -> np.ndarray: + """Per-token canonical (pre-transform) UTF-8 byte counts. + + Delegates to ``surface_piece_original_byte_counts`` in ``lossless_caps.py`` + — the canonical exporter used by the PR #1729 / HF-hosted CaseOps dataset. + Operator pieces (U+E001..U+E004) contribute 0 original bytes; letter pieces + contribute their pre-transform UTF-8 byte count. + """ + proto = sp.encode_as_immutable_proto(transformed_text) + byte_counts = surface_piece_original_byte_counts( + (piece.surface for piece in proto.pieces), + text_transform_name=LOSSLESS_CAPS_CASEOPS_V1, + ) + return np.asarray(list(byte_counts), dtype=np.uint16) + + +def main() -> None: + ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + ap.add_argument("--docs", required=True, type=pathlib.Path, help="Path to docs_selected.jsonl") + ap.add_argument("--out", required=True, type=pathlib.Path, help="Output datasets dir") + ap.add_argument("--sp", required=True, type=pathlib.Path, help="Path to CaseOps SP model") + ap.add_argument("--val-docs", type=int, default=10_000, help="Validation docs count") + ap.add_argument("--skip-docs", type=int, default=0, + help="Skip first N docs before processing (use to continue a partial run)") + ap.add_argument("--start-shard-train", type=int, default=0, + help="Start train shard numbering from this index (use to append to existing shards)") + ap.add_argument("--max-docs", type=int, default=0, + help="Stop after processing this many docs (0 = no limit)") + args = ap.parse_args() + + sp = spm.SentencePieceProcessor(model_file=str(args.sp)) + print(f"loaded sp: vocab={sp.vocab_size()}", flush=True) + + train_out = args.out / "datasets" / "fineweb10B_sp8192_lossless_caps_caseops_v1_reserved" + train_out.mkdir(parents=True, exist_ok=True) + + val_buf_tokens: list[int] = [] + val_buf_bytes: list[int] = [] + train_buf: list[int] = [] + val_written = 0 + train_written = args.start_shard_train + n_docs = 0 + + for text in _iter_docs(args.docs, skip=args.skip_docs): + transformed = encode_lossless_caps_v2(text) + token_ids = [BOS_ID] + sp.encode(transformed, out_type=int) + if n_docs < args.val_docs: + # Validation doc — also compute byte sidecar + byte_counts = _token_original_byte_counts(sp, text, transformed) + val_buf_tokens.extend(token_ids) + val_buf_bytes.append(0) # BOS contributes 0 original bytes + val_buf_bytes.extend(int(b) for b in byte_counts) + if len(val_buf_tokens) >= SHARD_TOKENS: + _write_shard(train_out / f"fineweb_val_{val_written:06d}.bin", + np.array(val_buf_tokens[:SHARD_TOKENS], dtype=np.uint16)) + _write_shard(train_out / f"fineweb_val_bytes_{val_written:06d}.bin", + np.array(val_buf_bytes[:SHARD_TOKENS], dtype=np.uint16)) + val_buf_tokens = val_buf_tokens[SHARD_TOKENS:] + val_buf_bytes = val_buf_bytes[SHARD_TOKENS:] + val_written += 1 + else: + train_buf.extend(token_ids) + if len(train_buf) >= SHARD_TOKENS: + _write_shard(train_out / f"fineweb_train_{train_written:06d}.bin", + np.array(train_buf[:SHARD_TOKENS], dtype=np.uint16)) + train_buf = train_buf[SHARD_TOKENS:] + train_written += 1 + n_docs += 1 + if n_docs % 10_000 == 0: + print(f" processed {n_docs} docs train_shards={train_written} val_shards={val_written}", flush=True) + if args.max_docs > 0 and n_docs >= args.max_docs: + break + + # Flush tail buffers into final (possibly short) shards. + if val_buf_tokens: + _write_shard(train_out / f"fineweb_val_{val_written:06d}.bin", + np.array(val_buf_tokens, dtype=np.uint16)) + _write_shard(train_out / f"fineweb_val_bytes_{val_written:06d}.bin", + np.array(val_buf_bytes, dtype=np.uint16)) + if train_buf: + _write_shard(train_out / f"fineweb_train_{train_written:06d}.bin", + np.array(train_buf, dtype=np.uint16)) + + print(f"done. docs={n_docs} train_shards={train_written + (1 if train_buf else 0)} val_shards={val_written + (1 if val_buf_tokens else 0)}") + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/requirements.txt b/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/requirements.txt new file mode 100644 index 0000000000..b6c55e13aa --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/requirements.txt @@ -0,0 +1,13 @@ +# Python deps. Install with: pip install -r requirements.txt +torch==2.9.1+cu128 +sentencepiece +brotli +huggingface_hub +numpy +python-minifier + +# FlashAttention 3 must be installed separately (not on PyPI): +# pip install --no-deps flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291/ + +# System dep (apt): lrzip (used by per-group compressor) +# apt-get install -y lrzip diff --git a/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/submission.json b/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/submission.json new file mode 100644 index 0000000000..9ff9a68b89 --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/submission.json @@ -0,0 +1,83 @@ +{ + "author": "Benjamin Hadad", + "github_id": "codemath3000", + "name": "PR #1797 base + n-gram tilt + AsymLogit + #2060 levers + NUM_PHASES=1 + GATE_WINDOW=8", + "blurb": "Single-knob (resp. two-knob) refinement on top of PR #2130: GATE_WINDOW=8. Every other hyperparameter, env var, and code path is byte-for-byte the PR #2130 reproduction command. Full validation target coverage preserved (val_tokens == target_tokens == 47851520).", + "date": "2026-05-01", + "track": "10min_16mb", + "val_loss": null, + "val_bpb": null, + "val_loss_std": null, + "val_bpb_std": null, + "seeds": [ + 314, + 42, + 0 + ], + "seed_results": { + "314": { + "val_loss": null, + "val_bpb": null, + "artifact_bytes": null, + "step_avg_ms": null, + "train_wallclock_ms": null, + "eval_time_s": null, + "val_tokens": 47851520, + "target_tokens": 47851520, + "pre_quant_val_bpb": null, + "quantized_val_bpb": null + }, + "42": { + "val_loss": null, + "val_bpb": null, + "artifact_bytes": null, + "step_avg_ms": null, + "train_wallclock_ms": null, + "eval_time_s": null, + "val_tokens": 47851520, + "target_tokens": 47851520, + "pre_quant_val_bpb": null, + "quantized_val_bpb": null + }, + "0": { + "val_loss": null, + "val_bpb": null, + "artifact_bytes": null, + "step_avg_ms": null, + "train_wallclock_ms": null, + "eval_time_s": null, + "val_tokens": 47851520, + "target_tokens": 47851520, + "pre_quant_val_bpb": null, + "quantized_val_bpb": null + } + }, + "comparison_baseline": "PR #2130", + "comparison_baseline_bpb": 1.0567, + "delta_vs_baseline_bpb": null, + "merged_leaderboard_baseline": "PR #1855", + "merged_leaderboard_baseline_bpb": 1.06107587, + "delta_vs_leaderboard_bpb": null, + "artifact_bytes_max": null, + "train_wallclock_ms_max": null, + "step_avg_ms_mean": null, + "eval_time_s_mean": null, + "eval_time_s_max": null, + "hardware": "8xH100 80GB SXM", + "pytorch_version": "2.9.1+cu128", + "lever_changes": [ + { + "name": "GATE_WINDOW", + "from": 12, + "to": 8 + } + ], + "compliance": { + "artifact_under_16mb": null, + "train_under_600s": null, + "eval_under_600s": null, + "n_seeds": 3, + "full_validation_targets": true, + "val_tokens_equal_target_tokens": true + } +} \ No newline at end of file diff --git a/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model b/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model new file mode 100644 index 0000000000..fffc8bb306 Binary files /dev/null and b/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/tokenizers/fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model differ diff --git a/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/train_gpt.py b/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/train_gpt.py new file mode 100644 index 0000000000..b9b8401656 --- /dev/null +++ b/records/track_10min_16mb/2026-05-01_SP8192_PR2130Base_GW8/train_gpt.py @@ -0,0 +1,4393 @@ +import base64, collections, copy, fcntl, glob, io, lzma, math, os +from pathlib import Path +import random, re, subprocess, sys, time, uuid, numpy as np, sentencepiece as spm, torch, torch.distributed as dist, torch.nn.functional as F +from torch import Tensor, nn +from flash_attn_interface import ( + flash_attn_func as flash_attn_3_func, + flash_attn_varlen_func, +) +from concurrent.futures import ThreadPoolExecutor +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + + +# ===== Fused softcapped cross-entropy (Triton) — training-only path ===== +# Replaces the eager +# logits_softcap = softcap * tanh(logits / softcap) +# F.cross_entropy(logits_softcap.float(), targets, reduction="mean") +# sequence with a single fused kernel that reads logits_proj once, applies +# softcap in-register, and computes (LSE, loss) in one streaming pass. The +# backward kernel mirrors the forward so there's no stored softcapped logits. +# Numerically identical to the eager path up to fp32 accumulation differences. +_FUSED_CE_LIBRARY = "pgsubmission1draft7fusedce" +_FUSED_CE_BLOCK_SIZE = 1024 +_FUSED_CE_NUM_WARPS = 4 + + +@triton.jit +def _softcapped_ce_fwd_kernel( + logits_ptr, losses_ptr, lse_ptr, targets_ptr, + stride_logits_n, stride_logits_v, + n_rows, n_cols, softcap, + block_size: tl.constexpr, +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + max_val = -float("inf") + sum_exp = 0.0 + A = 2.0 * softcap + inv_C = 2.0 / softcap + for off in range(0, n_cols, block_size): + cols = off + tl.arange(0, block_size) + mask = cols < n_cols + val = tl.load( + logits_row_ptr + cols * stride_logits_v, + mask=mask, other=-float("inf"), + ).to(tl.float32) + z = A * tl.sigmoid(val * inv_C) + z = tl.where(mask, z, -float("inf")) + curr_max = tl.max(z, axis=0) + new_max = tl.maximum(max_val, curr_max) + sum_exp = sum_exp * tl.exp(max_val - new_max) + tl.sum(tl.exp(z - new_max), axis=0) + max_val = new_max + lse = max_val + tl.log(sum_exp) + tl.store(lse_ptr + row_idx, lse) + target = tl.load(targets_ptr + row_idx).to(tl.int32) + target_val = tl.load(logits_row_ptr + target * stride_logits_v).to(tl.float32) + target_z = A * tl.sigmoid(target_val * inv_C) + tl.store(losses_ptr + row_idx, lse - target_z) + + +@triton.jit +def _softcapped_ce_bwd_kernel( + grad_logits_ptr, grad_losses_ptr, lse_ptr, logits_ptr, targets_ptr, + stride_logits_n, stride_logits_v, + stride_grad_n, stride_grad_v, + n_rows, n_cols, softcap, + block_size: tl.constexpr, +): + row_idx = tl.program_id(0).to(tl.int64) + logits_row_ptr = logits_ptr + row_idx * stride_logits_n + grad_row_ptr = grad_logits_ptr + row_idx * stride_grad_n + lse = tl.load(lse_ptr + row_idx) + grad_loss = tl.load(grad_losses_ptr + row_idx).to(tl.float32) + target = tl.load(targets_ptr + row_idx).to(tl.int32) + A = 2.0 * softcap + inv_C = 2.0 / softcap + dz_dx_scale = A * inv_C + for off in range(0, n_cols, block_size): + cols = off + tl.arange(0, block_size) + mask = cols < n_cols + val = tl.load( + logits_row_ptr + cols * stride_logits_v, + mask=mask, other=0.0, + ).to(tl.float32) + sigmoid_u = tl.sigmoid(val * inv_C) + z = A * sigmoid_u + probs = tl.exp(z - lse) + grad_z = grad_loss * (probs - tl.where(cols == target, 1.0, 0.0)) + grad_x = grad_z * (dz_dx_scale * sigmoid_u * (1.0 - sigmoid_u)) + tl.store(grad_row_ptr + cols * stride_grad_v, grad_x, mask=mask) + + +def _validate_softcapped_ce_inputs( + logits: Tensor, targets: Tensor, softcap: float, +) -> tuple[Tensor, Tensor]: + if logits.ndim != 2: + raise ValueError(f"Expected logits.ndim=2, got {logits.ndim}") + if targets.ndim != 1: + raise ValueError(f"Expected targets.ndim=1, got {targets.ndim}") + if logits.shape[0] != targets.shape[0]: + raise ValueError( + f"Expected matching rows, got logits={tuple(logits.shape)} targets={tuple(targets.shape)}" + ) + if not logits.is_cuda or not targets.is_cuda: + raise ValueError("softcapped_cross_entropy requires CUDA tensors") + if softcap <= 0.0: + raise ValueError(f"softcap must be positive, got {softcap}") + if logits.dtype not in (torch.float16, torch.bfloat16, torch.float32): + raise ValueError(f"Unsupported logits dtype: {logits.dtype}") + logits = logits.contiguous() + targets = targets.contiguous() + if targets.dtype != torch.int64: + targets = targets.to(dtype=torch.int64) + return logits, targets + + +@torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce", mutates_args=()) +def softcapped_ce_op(logits: Tensor, targets: Tensor, softcap: float) -> tuple[Tensor, Tensor]: + logits, targets = _validate_softcapped_ce_inputs(logits, targets, float(softcap)) + n_rows, n_cols = logits.shape + losses = torch.empty((n_rows,), device=logits.device, dtype=torch.float32) + lse = torch.empty((n_rows,), device=logits.device, dtype=torch.float32) + _softcapped_ce_fwd_kernel[(n_rows,)]( + logits, losses, lse, targets, + logits.stride(0), logits.stride(1), + n_rows, n_cols, float(softcap), + block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS, + ) + return losses, lse + + +@softcapped_ce_op.register_fake +def _(logits: Tensor, targets: Tensor, softcap: float): + if logits.ndim != 2 or targets.ndim != 1: + raise ValueError("softcapped_ce fake impl expects 2D logits and 1D targets") + if logits.shape[0] != targets.shape[0]: + raise ValueError( + f"Expected matching rows, got logits={tuple(logits.shape)} targets={tuple(targets.shape)}" + ) + n_rows = logits.shape[0] + return ( + logits.new_empty((n_rows,), dtype=torch.float32), + logits.new_empty((n_rows,), dtype=torch.float32), + ) + + +@torch.library.custom_op(f"{_FUSED_CE_LIBRARY}::softcapped_ce_backward", mutates_args=()) +def softcapped_ce_backward_op( + logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float, +) -> Tensor: + logits, targets = _validate_softcapped_ce_inputs(logits, targets, float(softcap)) + lse = lse.contiguous() + grad_losses = grad_losses.contiguous().to(dtype=torch.float32) + if lse.ndim != 1 or grad_losses.ndim != 1: + raise ValueError("Expected 1D lse and grad_losses") + if lse.shape[0] != logits.shape[0] or grad_losses.shape[0] != logits.shape[0]: + raise ValueError( + f"Expected row-aligned lse/grad_losses, got logits={tuple(logits.shape)} " + f"lse={tuple(lse.shape)} grad_losses={tuple(grad_losses.shape)}" + ) + grad_logits = torch.empty_like(logits) + n_rows, n_cols = logits.shape + _softcapped_ce_bwd_kernel[(n_rows,)]( + grad_logits, grad_losses, lse, logits, targets, + logits.stride(0), logits.stride(1), + grad_logits.stride(0), grad_logits.stride(1), + n_rows, n_cols, float(softcap), + block_size=_FUSED_CE_BLOCK_SIZE, num_warps=_FUSED_CE_NUM_WARPS, + ) + return grad_logits + + +@softcapped_ce_backward_op.register_fake +def _(logits: Tensor, targets: Tensor, lse: Tensor, grad_losses: Tensor, softcap: float): + if logits.ndim != 2 or targets.ndim != 1 or lse.ndim != 1 or grad_losses.ndim != 1: + raise ValueError("softcapped_ce_backward fake impl expects 2D logits and 1D row tensors") + if ( + logits.shape[0] != targets.shape[0] + or logits.shape[0] != lse.shape[0] + or logits.shape[0] != grad_losses.shape[0] + ): + raise ValueError("softcapped_ce_backward fake impl expects row-aligned tensors") + return logits.new_empty(logits.shape) + + +def _softcapped_ce_setup_context( + ctx: torch.autograd.function.FunctionCtx, inputs, output, +) -> None: + logits, targets, softcap = inputs + _losses, lse = output + ctx.save_for_backward(logits, targets, lse) + ctx.softcap = float(softcap) + + +def _softcapped_ce_backward( + ctx: torch.autograd.function.FunctionCtx, grad_losses: Tensor, grad_lse: "Tensor | None", +): + del grad_lse + logits, targets, lse = ctx.saved_tensors + grad_logits = torch.ops.pgsubmission1draft7fusedce.softcapped_ce_backward( + logits, targets, lse, grad_losses, ctx.softcap + ) + return grad_logits, None, None + + +softcapped_ce_op.register_autograd( + _softcapped_ce_backward, setup_context=_softcapped_ce_setup_context, +) + + +def softcapped_cross_entropy( + logits: Tensor, targets: Tensor, softcap: float, reduction: str = "mean", +) -> Tensor: + losses, _lse = torch.ops.pgsubmission1draft7fusedce.softcapped_ce( + logits, targets, float(softcap) + ) + if reduction == "none": + return losses + if reduction == "sum": + return losses.sum() + if reduction == "mean": + return losses.mean() + raise ValueError(f"Unsupported reduction={reduction!r}") + + +class Hyperparameters: + data_dir = os.environ.get("DATA_DIR", "./data/") + seed = int(os.environ.get("SEED", 1337)) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_frac = float(os.environ.get("WARMDOWN_FRAC", 0.75)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786432)) + # Fused softcapped CE (Triton). Training-only — forward_logits eval path still uses + # eager softcap+F.cross_entropy. Default ON since validated as at-worst neutral. + fused_ce_enabled = bool(int(os.environ.get("FUSED_CE_ENABLED", "1"))) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 6e2)) + val_batch_tokens = int(os.environ.get("VAL_BATCH_TOKENS", 524288)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 8192)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 4.0)) + skip_gates_enabled = bool(int(os.environ.get("SKIP_GATES_ENABLED", "1"))) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 3e1)) + asym_logit_rescale = bool(int(os.environ.get("ASYM_LOGIT_RESCALE", "0"))) + rope_base = float(os.environ.get("ROPE_BASE", 1e4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + rope_train_seq_len = int(os.environ.get("ROPE_TRAIN_SEQ_LEN", 2048)) + rope_yarn = bool(int(os.environ.get("ROPE_YARN", "0"))) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.0)) + # Layer 1: per-layer QK-Gain init schedule. Comma-separated floats, one per physical + # layer. Falls back to uniform qk_gain_init if empty. Schedule is an initialization + # choice — q_gain remains trainable and evolves from this starting point. + _qk_sched_raw = os.environ.get("QK_GAIN_INIT_SCHEDULE", "") + qk_gain_schedule = ( + [float(x) for x in _qk_sched_raw.split(",") if x.strip()] + if _qk_sched_raw.strip() else [] + ) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + loop_start = int(os.environ.get("LOOP_START", 3)) + loop_end = int(os.environ.get("LOOP_END", 5)) + enable_looping_at = float(os.environ.get("ENABLE_LOOPING_AT", 0.35)) + parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", 8)) + parallel_final_lane = os.environ.get("PARALLEL_FINAL_LANE", "mean") + min_lr = float(os.environ.get("MIN_LR", 0.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.026)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.97)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float( + os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92) + ) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + muon_row_normalize = bool(int(os.environ.get("MUON_ROW_NORMALIZE", "1"))) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-08)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + adam_wd = float(os.environ.get("ADAM_WD", 0.02)) + muon_wd = float(os.environ.get("MUON_WD", 0.095)) + embed_wd = float(os.environ.get("EMBED_WD", 0.085)) + # Layer 3: AdamHD Huber weight decay for Muon. Replaces L2 WD with Huber regularizer: + # quadratic below |w| <= delta (standard L2 behavior), linear above (clips large-weight + # gradient contribution). Specifically suppresses outlier weights that dominate int6 + # quantization error. delta=0 falls back to standard L2 decay. + muon_huber_wd = bool(int(os.environ.get("MUON_HUBER_WD", "0"))) + muon_huber_delta = float(os.environ.get("MUON_HUBER_DELTA", "0.1")) + ema_decay = float(os.environ.get("EMA_DECAY", 0.9965)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 96)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.0001)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 48)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 2048)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + ttt_grad_steps = int(os.environ.get("TTT_GRAD_STEPS", 1)) + ttt_grad_steps_clean = bool(int(os.environ.get("TTT_GRAD_STEPS_CLEAN", "0"))) + # PR #1145 online n-gram tilt (AnirudhRahul, valerio-endorsed). Causal, + # normalized, prefix-only experts; closed-form multiplicative-boost-with-renorm + # applied to per-token NLL at scoring time only. See online_ngram_tilt.py. + ngram_tilt_enabled = bool(int(os.environ.get("NGRAM_TILT_ENABLED", "0"))) + token_order = int(os.environ.get("TOKEN_ORDER", "16")) + token_threshold = float(os.environ.get("TOKEN_THRESHOLD", "0.800")) + token_boost = float(os.environ.get("TOKEN_BOOST", "2.625")) + # within-doc and word-start experts gate on target_i properties (is_new_word, + # is_boundary applied to the token being SCORED), which violates C1 causality. + # Defaults 99.0 ensure their gates never fire. Token-only is the legal subset + # (confirmed by PR #1514 merge precedent). + within_tau = float(os.environ.get("WITHIN_TAU", "99.0")) + within_boost = float(os.environ.get("WITHIN_BOOST", "0.0")) + word_order = int(os.environ.get("WORD_ORDER", "4")) + word_normalize = os.environ.get("WORD_NORMALIZE", "strip_punct_lower") + word_tau = float(os.environ.get("WORD_TAU", "99.0")) + word_boost = float(os.environ.get("WORD_BOOST", "0.0")) + agree_add_boost = float(os.environ.get("AGREE_ADD_BOOST", "0.500")) + ngram_hint_precompute_outside = bool(int(os.environ.get("NGRAM_HINT_PRECOMPUTE_OUTSIDE", "1"))) + ttt_weight_decay = float(os.environ.get("TTT_WEIGHT_DECAY", 1.0)) + ttt_beta1 = float(os.environ.get("TTT_BETA1", 0)) + ttt_beta2 = float(os.environ.get("TTT_BETA2", 0.999)) + ttt_k_lora = bool(int(os.environ.get("TTT_K_LORA", "1"))) + ttt_mlp_lora = bool(int(os.environ.get("TTT_MLP_LORA", "1"))) + ttt_o_lora = bool(int(os.environ.get("TTT_O_LORA", "1"))) + ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adam") + ttt_eval_batches = os.environ.get("TTT_EVAL_BATCHES", "") + val_doc_fraction = float(os.environ.get("VAL_DOC_FRACTION", 1.0)) + compressor = os.environ.get("COMPRESSOR", "brotli") + gptq_calibration_batches = int(os.environ.get("GPTQ_CALIBRATION_BATCHES", 16)) + gptq_reserve_seconds = float(os.environ.get("GPTQ_RESERVE_SECONDS", 4.0)) + phased_ttt_prefix_docs = int(os.environ.get("PHASED_TTT_PREFIX_DOCS", 2000)) + phased_ttt_num_phases = int(os.environ.get("PHASED_TTT_NUM_PHASES", 1)) + global_ttt_lr = float(os.environ.get("GLOBAL_TTT_LR", 0.001)) + global_ttt_momentum = float(os.environ.get("GLOBAL_TTT_MOMENTUM", 0.9)) + global_ttt_epochs = int(os.environ.get("GLOBAL_TTT_EPOCHS", 1)) + global_ttt_chunk_tokens = int(os.environ.get("GLOBAL_TTT_CHUNK_TOKENS", 32768)) + global_ttt_batch_seqs = int(os.environ.get("GLOBAL_TTT_BATCH_SEQS", 32)) + global_ttt_warmup_start_lr = float(os.environ.get("GLOBAL_TTT_WARMUP_START_LR", 0.0)) + global_ttt_warmup_chunks = int(os.environ.get("GLOBAL_TTT_WARMUP_CHUNKS", 0)) + global_ttt_grad_clip = float(os.environ.get("GLOBAL_TTT_GRAD_CLIP", 1.0)) + global_ttt_respect_doc_boundaries = bool(int(os.environ.get("GLOBAL_TTT_RESPECT_DOC_BOUNDARIES", "1"))) + # Layer 4: LaCT global TTT optimizer. "sgd" = original SGD (default). "muon" = Muon + # Newton-Schulz orthogonalized updates for matrix params, SGD for scalars/embeddings. + # LaCT (arxiv 2505.23884) uses large document chunks + Muon fast-weight updates to + # achieve 70% GPU utilization vs <5% for per-token TTT, enabling more TTT epochs. + global_ttt_optimizer = os.environ.get("GLOBAL_TTT_OPTIMIZER", "sgd") + global_ttt_muon_ns_steps = int(os.environ.get("GLOBAL_TTT_MUON_NS_STEPS", "5")) + global_ttt_muon_nesterov = bool(int(os.environ.get("GLOBAL_TTT_MUON_NESTEROV", "1"))) + # Layer 2: OptRot pre-quantization Hadamard rotation. Rotates (W_up, W_down) and + # (W_v, W_o) pairs via Hadamard matrices, redistributing outlier weights before GPTQ. + # Orthogonal transformation: model outputs are preserved exactly; quantization error + # is reduced 30-50% (paper: arxiv 2512.24124). Zero artifact cost (fused into weights). + optrot_enabled = bool(int(os.environ.get("OPTROT_ENABLED", "0"))) + matrix_bits = int(os.environ.get("MATRIX_BITS", 6)) + embed_bits = int(os.environ.get("EMBED_BITS", 8)) + matrix_clip_sigmas = float(os.environ.get("MATRIX_CLIP_SIGMAS", 12.85)) + embed_clip_sigmas = float(os.environ.get("EMBED_CLIP_SIGMAS", 2e1)) + mlp_clip_sigmas = float(os.environ.get("MLP_CLIP_SIGMAS", 10.0)) + attn_clip_sigmas = float(os.environ.get("ATTN_CLIP_SIGMAS", 13.0)) + # AttnOutGate (per-head multiplicative output gate, PR #1667 MarioPaerle). + # Zero-init weight: 2*sigmoid(0)=1 -> transparent at start. Source defaults to + # block input x ('proj'); 'q' uses raw Q projection output. + attn_out_gate_enabled = bool(int(os.environ.get("ATTN_OUT_GATE_ENABLED", "0"))) + attn_out_gate_src = os.environ.get("ATTN_OUT_GATE_SRC", "proj") + # SmearGate (input-dependent forward-1 token smear, modded-nanogpt @classiclarryd + # via PR #1667). x_t <- x_t + lam * sigmoid(W*x_t[:gate_window]) * x_{t-1}. + # lam=0 + W=0 -> transparent at init. + smear_gate_enabled = bool(int(os.environ.get("SMEAR_GATE_ENABLED", "0"))) + # Window: first GATE_WINDOW dims of the source feed the gate projection. + gate_window = int(os.environ.get("GATE_WINDOW", 12)) + # Gated Attention (Qwen, NeurIPS 2025 Best Paper, arXiv:2505.06708; + # qiuzh20/gated_attention). Per-head sigmoid gate on SDPA output, BEFORE + # out_proj. Gate input = full block input x (paper's headwise G1 variant + # driven from hidden_states). W_g shape (num_heads, dim), plain sigmoid. + # Near-zero init gives g~0.5 at step 0 (half attention output); per-block + # attn_scale (init 1.0) compensates during training. Name contains + # "attn_gate" so CONTROL_TENSOR_NAME_PATTERNS routes it to scalar AdamW. + gated_attn_enabled = bool(int(os.environ.get("GATED_ATTN_ENABLED", "0"))) + gated_attn_init_std = float(os.environ.get("GATED_ATTN_INIT_STD", 0.01)) + # Dedicated int8-per-row quantization for `attn_gate_w` tensors. These are + # small ((num_heads, dim) = (8, 512) = 4096 params) and bypass GPTQ via the + # numel<=65536 passthrough branch -> stored as fp16 (8 KB/layer, ~65 KB total + # compressed). int8-per-row cuts the raw tensor in half with negligible BPB + # impact: scales per head (8 values), symmetric quant over [-127, 127]. + # No Hessian needed (gate weights not in collect_hessians()). + gated_attn_quant_gate = bool(int(os.environ.get("GATED_ATTN_QUANT_GATE", "0"))) + # Sparse Attention Gate (modded-nanogpt-style). Keeps dense SDPA and only + # swaps the output-gate input to the first GATE_WINDOW residual dims. + # W_g: (num_heads, gate_window) = (8, 12) = 96 params/layer (~44K total), + # vs dense GatedAttn's (8, 512) = 4K/layer (~44K diff). Name "attn_gate_w" + # is shared so quant routing and int8 gate passthrough Just Work. Gate + # passthrough int8 still applies via GATED_ATTN_QUANT_GATE=1. + # Mutually exclusive with ATTN_OUT_GATE_ENABLED and GATED_ATTN_ENABLED. + sparse_attn_gate_enabled = bool(int(os.environ.get("SPARSE_ATTN_GATE_ENABLED", "0"))) + sparse_attn_gate_init_std = float(os.environ.get("SPARSE_ATTN_GATE_INIT_STD", 0.0)) + sparse_attn_gate_scale = float(os.environ.get("SPARSE_ATTN_GATE_SCALE", 1.0)) + # LQER asymmetric rank-k correction on top-K quant-error tensors (PR #1530 v2 port). + # Computes SVD of E = W_fp - W_quant, packs top-r A,B as INT2/INT4 (asym) or INTk (sym). + lqer_enabled = bool(int(os.environ.get("LQER_ENABLED", "1"))) + lqer_rank = int(os.environ.get("LQER_RANK", 4)) + lqer_top_k = int(os.environ.get("LQER_TOP_K", 3)) + lqer_factor_bits = int(os.environ.get("LQER_FACTOR_BITS", 4)) + lqer_asym_enabled = bool(int(os.environ.get("LQER_ASYM_ENABLED", "1"))) + lqer_asym_group = int(os.environ.get("LQER_ASYM_GROUP", "64")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + # CaseOps integration: optional override of dataset root + tokenizer path. + # When CASEOPS_ENABLED=1, the wrapper loads a per-token byte sidecar + # (fineweb_val_bytes_*.bin, identical shard layout to val_*.bin) and uses + # it as the canonical raw-byte budget for BPB accounting. The sidecar + # REPLACES the build_sentencepiece_luts byte-counting path entirely. + caseops_enabled = bool(int(os.environ.get("CASEOPS_ENABLED", "0"))) + _default_caseops_data = os.path.join( + data_dir, + "datasets", + "fineweb10B_sp8192_caseops", + "datasets", + "datasets", + "fineweb10B_sp8192_lossless_caps_caseops_v1_reserved", + ) + _default_caseops_tok = os.path.join( + data_dir, + "datasets", + "fineweb10B_sp8192_caseops", + "datasets", + "tokenizers", + "fineweb_8192_bpe_lossless_caps_caseops_v1_reserved.model", + ) + if caseops_enabled: + datasets_dir = os.environ.get("DATA_PATH", _default_caseops_data) + tokenizer_path = os.environ.get("TOKENIZER_PATH", _default_caseops_tok) + else: + datasets_dir = os.environ.get( + "DATA_PATH", + os.path.join(data_dir, "datasets", f"fineweb10B_sp{vocab_size}"), + ) + tokenizer_path = os.environ.get( + "TOKENIZER_PATH", + os.path.join(data_dir, "tokenizers", f"fineweb_{vocab_size}_bpe.model"), + ) + train_files = os.path.join(datasets_dir, "fineweb_train_*.bin") + val_files = os.path.join(datasets_dir, "fineweb_val_*.bin") + val_bytes_files = os.path.join(datasets_dir, "fineweb_val_bytes_*.bin") + artifact_dir = os.environ.get("ARTIFACT_DIR", "") + logfile = ( + os.path.join(artifact_dir, f"{run_id}.txt") + if artifact_dir + else f"logs/{run_id}.txt" + ) + model_path = ( + os.path.join(artifact_dir, "final_model.pt") + if artifact_dir + else "final_model.pt" + ) + quantized_model_path = ( + os.path.join(artifact_dir, "final_model.int6.ptz") + if artifact_dir + else "final_model.int6.ptz" + ) + + +_logger_hparams = None + + +def set_logging_hparams(h): + global _logger_hparams + _logger_hparams = h + + +def log(msg, console=True): + if _logger_hparams is None: + print(msg) + return + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + +class ValidationData: + def __init__(self, h, device): + self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError( + f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}" + ) + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + ( + self.base_bytes_lut, + self.has_leading_space_lut, + self.is_boundary_token_lut, + ) = build_sentencepiece_luts(self.sp, h.vocab_size, device) + # CaseOps: when enabled, load per-token byte sidecar and stash it as a + # CPU tensor aligned 1:1 with self.val_tokens. eval_val/eval_val_ttt + # branches use this as the canonical raw-byte budget per token. + self.caseops_enabled = bool(getattr(h, "caseops_enabled", False)) + self.val_bytes = None + if self.caseops_enabled: + self.val_bytes = load_validation_byte_sidecar( + h.val_bytes_files, h.eval_seq_len, self.val_tokens.numel() + ) + + +def build_sentencepiece_luts(sp, vocab_size, device): + sp_vocab_size = int(sp.vocab_size()) + assert ( + sp.piece_to_id("▁") != sp.unk_id() + ), "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting" + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern, seq_len): + # Filter out CaseOps byte sidecar shards which share the val_*.bin glob. + files = [ + Path(p) + for p in sorted(glob.glob(pattern)) + if "_bytes_" not in Path(p).name + ] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = (tokens.numel() - 1) // seq_len * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def load_validation_byte_sidecar(pattern, seq_len, expected_len): + """Load CaseOps per-token byte sidecar(s). Same shard layout as token shards + (256 int32 header + uint16 array). Each entry = canonical raw-text byte + budget for that token in the corresponding val shard. Returns a CPU + int16 tensor sliced to match expected_len (i.e. val_tokens length).""" + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No byte sidecar files for pattern: {pattern}") + shards = [load_data_shard(file) for file in files] + # load_data_shard returns uint16 — that's exactly what the sidecar stores. + bytes_full = torch.cat(shards).contiguous() + if bytes_full.numel() < expected_len: + raise ValueError( + f"Byte sidecar too short: {bytes_full.numel()} < val_tokens {expected_len}" + ) + return bytes_full[:expected_len].to(torch.int32) + + +def load_data_shard(file): + header_bytes = 256 * np.dtype(" 0: + pos = start + while pos < end: + seg_starts.append(pos) + pos += max_doc_len + else: + seg_starts.append(start) + boundaries = seg_starts + [total_len] + padded_len = get_next_multiple_of_n(len(boundaries), bucket_size) + cu = torch.full((padded_len,), total_len, dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + seg_ends = seg_starts[1:] + [total_len] + max_seqlen = max(end - start for start, end in zip(seg_starts, seg_ends)) + return cu, max_seqlen + +class DocumentPackingLoader: + _shard_pool = ThreadPoolExecutor(1) + + def __init__(self, h, device, cu_bucket_size=64): + self.rank = h.rank + self.world_size = h.world_size + self.device = device + self.cu_bucket_size = cu_bucket_size + self.max_seq_len = h.train_seq_len + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files + self.file_iter = iter(self.files) + self._init_shard(load_data_shard(next(self.file_iter))) + self._next_shard = self._submit_next_shard() + self._batch_pool = ThreadPoolExecutor(1) + self._next_batch = None + + def _init_shard(self, tokens): + global BOS_ID + self.tokens = tokens + self.shard_size = tokens.numel() + if BOS_ID is None: + BOS_ID = 1 + self.bos_idx = ( + (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + ) + self.cursor = int(self.bos_idx[0]) + + def _submit_next_shard(self): + try: + path = next(self.file_iter) + return self._shard_pool.submit(load_data_shard, path) + except StopIteration: + return None + + def _advance_shard(self): + if self._next_shard is None: + self.file_iter = iter(self.files) + self._next_shard = self._shard_pool.submit( + load_data_shard, next(self.file_iter) + ) + self._init_shard(self._next_shard.result()) + self._next_shard = self._submit_next_shard() + + def _local_doc_starts(self, local_start, total_len): + lo = np.searchsorted(self.bos_idx, local_start, side="left") + hi = np.searchsorted(self.bos_idx, local_start + total_len, side="left") + return (self.bos_idx[lo:hi] - local_start).tolist() + + def _prepare_batch(self, num_tokens_local, max_seq_len): + per_rank_span = num_tokens_local + 1 + global_span = per_rank_span * self.world_size + while self.cursor + global_span > self.shard_size: + self._advance_shard() + local_start = self.cursor + self.rank * per_rank_span + buf = self.tokens[local_start : local_start + per_rank_span] + inputs = buf[:-1].to(dtype=torch.int64).pin_memory() + targets = buf[1:].to(dtype=torch.int64).pin_memory() + starts = self._local_doc_starts(local_start, inputs.numel()) + cu_seqlens, max_seqlen = _build_cu_seqlens( + starts, inputs.numel(), inputs.device, max_seq_len, self.cu_bucket_size + ) + cu_seqlens = cu_seqlens.pin_memory() + self.cursor += global_span + return inputs, targets, cu_seqlens, max_seqlen + + def next_batch(self, global_tokens, grad_accum_steps): + num_tokens_local = global_tokens // (self.world_size * grad_accum_steps) + if self._next_batch is not None: + inputs, targets, cu_seqlens, max_seqlen = self._next_batch.result() + else: + inputs, targets, cu_seqlens, max_seqlen = self._prepare_batch( + num_tokens_local, self.max_seq_len + ) + self._next_batch = self._batch_pool.submit( + self._prepare_batch, num_tokens_local, self.max_seq_len + ) + return ( + inputs[None].to(self.device, non_blocking=True), + targets[None].to(self.device, non_blocking=True), + cu_seqlens.to(self.device, non_blocking=True), + max_seqlen, + ) + + +class ShuffledSequenceLoader: + def __init__(self, h, device): + self.world_size = h.world_size + self.seq_len = h.train_seq_len + self.device = device + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + self.files = all_files[h.rank :: h.world_size] + self.rng = np.random.Generator(np.random.PCG64(h.rank)) + self.num_tokens = [_read_num_tokens(f) for f in self.files] + self.start_inds = [[] for _ in self.files] + for si in range(len(self.files)): + self._reset_shard(si) + + def _reset_shard(self, si): + max_phase = min( + self.seq_len - 1, max(0, self.num_tokens[si] - self.seq_len - 1) + ) + phase = int(self.rng.integers(max_phase + 1)) if max_phase > 0 else 0 + num_sequences = (self.num_tokens[si] - 1 - phase) // self.seq_len + sequence_order = self.rng.permutation(num_sequences) + self.start_inds[si] = (phase + sequence_order * self.seq_len).tolist() + + def next_batch(self, global_tokens, grad_accum_steps): + device_tokens = global_tokens // (self.world_size * grad_accum_steps) + device_batch_size = device_tokens // self.seq_len + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + y = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + for bi in range(device_batch_size): + total = remaining.sum() + if total <= 0: + for si in range(len(self.files)): + self._reset_shard(si) + remaining = np.array( + [len(s) for s in self.start_inds], dtype=np.float64 + ) + total = remaining.sum() + probs = remaining / total + si = int(self.rng.choice(len(self.files), p=probs)) + start_ind = self.start_inds[si].pop() + remaining[si] -= 1 + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor( + np.array(mm[start_ind : start_ind + self.seq_len + 1], dtype=np.int64) + ) + x[bi] = window[:-1] + y[bi] = window[1:] + return x.to(self.device, non_blocking=True), y.to( + self.device, non_blocking=True + ) + + +class DocStartSequenceLoader: + """GPTQ calibration loader yielding windows that begin at BOS positions. + + Matches ShuffledSequenceLoader.next_batch() interface exactly. + Used when GPTQ_CALIBRATION_MODE=doc_start to align calibration distribution + with eval (which processes document-structured data with BOS-prepended contexts). + """ + _N_SCAN_SHARDS = 8 + + def __init__(self, h, device, bos_id=1): + self.world_size = h.world_size + self.seq_len = h.train_seq_len + self.device = device + all_files = [Path(p) for p in sorted(glob.glob(h.train_files))] + if not all_files: + raise FileNotFoundError(f"No files found for pattern: {h.train_files}") + rank_files = all_files[h.rank :: h.world_size] + self.rng = np.random.Generator(np.random.PCG64(h.rank + 7919)) + scan_files = rank_files[: self._N_SCAN_SHARDS] + self._windows = [] # (path, start_offset) pairs starting at BOS + t0 = time.perf_counter() + for path in scan_files: + mm = _get_shard_memmap(path) + n = len(mm) + positions = np.where(mm[:] == np.uint16(bos_id))[0] + valid = positions[positions + self.seq_len + 1 <= n] + for pos in valid.tolist(): + self._windows.append((path, pos)) + log(f"DocStartLoader: {len(self._windows)} BOS windows from {len(scan_files)} shards in {time.perf_counter()-t0:.1f}s") + if not self._windows: + raise RuntimeError("DocStartSequenceLoader: no valid BOS windows found — check BOS_ID and shard files") + + def next_batch(self, global_tokens, grad_accum_steps): + device_tokens = global_tokens // (self.world_size * grad_accum_steps) + device_batch_size = device_tokens // self.seq_len + x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + y = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + idxs = self.rng.integers(0, len(self._windows), size=device_batch_size) + for bi, idx in enumerate(idxs): + path, start = self._windows[int(idx)] + mm = _get_shard_memmap(path) + window = torch.as_tensor( + np.array(mm[start : start + self.seq_len + 1], dtype=np.int64) + ) + x[bi] = window[:-1] + y[bi] = window[1:] + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +class RMSNorm(nn.Module): + def __init__(self, eps=None): + super().__init__() + self.eps = eps + + def forward(self, x): + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x): + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +@triton.jit +def linear_leaky_relu_square_kernel( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + NUM_SMS: tl.constexpr, + FORWARD: tl.constexpr, +): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + tile_id_c = start_pid - NUM_SMS + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + tile_id_c += NUM_SMS + offs_am_c = offs_am + offs_bn_c = offs_bn + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(dtype) + c1 = acc1.to(dtype) + if not FORWARD: + pre0 = aux_desc.load([offs_am_c, offs_bn_c]) + pre1 = aux_desc.load([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2]) + c0 = c0 * tl.where(pre0 > 0, 2.0 * pre0, 0.5 * pre0) + c1 = c1 * tl.where(pre1 > 0, 2.0 * pre1, 0.5 * pre1) + c_desc.store([offs_am_c, offs_bn_c], c0) + c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1) + if FORWARD: + aux0 = tl.where(c0 > 0, c0, 0.5 * c0) + aux1 = tl.where(c1 > 0, c1, 0.5 * c1) + aux_desc.store([offs_am_c, offs_bn_c], aux0 * aux0) + aux_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], aux1 * aux1) + + +def linear_leaky_relu_square(a, b, aux=None): + M, K = a.shape + N, K2 = b.shape + assert K == K2 + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + forward = aux is None + if aux is None: + aux = torch.empty((M, N), device=a.device, dtype=a.dtype) + num_sms = torch.cuda.get_device_properties(a.device).multi_processor_count + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 128, 256, 64 + num_stages = 4 if forward else 3 + a_desc = TensorDescriptor.from_tensor(a, [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor.from_tensor(b, [BLOCK_SIZE_N, BLOCK_SIZE_K]) + c_desc = TensorDescriptor.from_tensor(c, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + aux_desc = TensorDescriptor.from_tensor(aux, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2]) + grid = lambda _meta: ( + min(num_sms, triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N)), + ) + linear_leaky_relu_square_kernel[grid]( + a_desc, + b_desc, + c_desc, + aux_desc, + M, + N, + K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + NUM_SMS=num_sms, + FORWARD=forward, + num_stages=num_stages, + num_warps=8, + ) + if forward: + return c, aux + return c + + +class FusedLinearLeakyReLUSquareFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, w1, w2): + x_flat = x.reshape(-1, x.shape[-1]) + pre, post = linear_leaky_relu_square(x_flat, w1) + out = F.linear(post, w2) + ctx.save_for_backward(x, w1, w2, pre, post) + return out.view(*x.shape[:-1], out.shape[-1]) + + @staticmethod + def backward(ctx, grad_output): + x, w1, w2, pre, post = ctx.saved_tensors + x_flat = x.reshape(-1, x.shape[-1]) + grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1]) + dw2 = grad_output_flat.T @ post + dpre = linear_leaky_relu_square(grad_output_flat, w2.T.contiguous(), aux=pre) + dw1 = dpre.T @ x_flat + dx = dpre @ w1 + return dx.view_as(x), dw1, dw2 + + +FusedLeakyReLUSquareMLP = FusedLinearLeakyReLUSquareFunction.apply + + +@triton.jit +def fused_log_softmax_dual_gather_kernel( + logits_ptr, + target_ids_ptr, + hint_ids_ptr, + log_p_y_out_ptr, + log_q_h_out_ptr, + BT, + V, + BLOCK_V: tl.constexpr, +): + """Single pass over [BT, V] logits; extracts log p(target) and log p(hint).""" + pid = tl.program_id(0) + if pid >= BT: + return + target = tl.load(target_ids_ptr + pid) + hint = tl.load(hint_ids_ptr + pid) + row_offset = pid * V + target_logit = tl.load(logits_ptr + row_offset + target).to(tl.float32) + hint_logit = tl.load(logits_ptr + row_offset + hint).to(tl.float32) + NEG_INF = float("-inf") + max_val = NEG_INF + for v_start in tl.range(0, V, BLOCK_V): + v_offsets = v_start + tl.arange(0, BLOCK_V) + mask = v_offsets < V + chunk = tl.load(logits_ptr + row_offset + v_offsets, mask=mask, other=NEG_INF).to(tl.float32) + max_val = tl.maximum(max_val, tl.max(chunk, axis=0)) + sum_exp = tl.zeros((), dtype=tl.float32) + for v_start in tl.range(0, V, BLOCK_V): + v_offsets = v_start + tl.arange(0, BLOCK_V) + mask = v_offsets < V + chunk = tl.load(logits_ptr + row_offset + v_offsets, mask=mask, other=0.0).to(tl.float32) + sum_exp += tl.sum(tl.where(mask, tl.exp(chunk - max_val), 0.0), axis=0) + log_sum_exp = max_val + tl.log(sum_exp) + tl.store(log_p_y_out_ptr + pid, target_logit - log_sum_exp) + tl.store(log_q_h_out_ptr + pid, hint_logit - log_sum_exp) + + +def fused_log_softmax_dual_gather(logits, target_ids, hint_ids): + """Returns (log_p_y, log_q_h) where p = softmax(logits). No backward needed.""" + bsz, sl, V = logits.shape + BT = bsz * sl + logits_flat = logits.reshape(BT, V).contiguous() + log_p_y_out = torch.empty(BT, dtype=torch.float32, device=logits.device) + log_q_h_out = torch.empty(BT, dtype=torch.float32, device=logits.device) + fused_log_softmax_dual_gather_kernel[(BT,)]( + logits_flat, + target_ids.reshape(BT).contiguous(), + hint_ids.reshape(BT).contiguous(), + log_p_y_out, + log_q_h_out, + BT, V, BLOCK_V=1024, num_warps=8, + ) + return log_p_y_out.reshape(bsz, sl), log_q_h_out.reshape(bsz, sl) + + +class Rotary(nn.Module): + def __init__(self, dim, base=1e4, train_seq_len=1024, rope_dims=0, yarn=True): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.yarn = yarn + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / base ** ( + torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def forward(self, seq_len, device, dtype): + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached < seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if self.yarn and seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * scale ** (rd / (rd - 2)) + inv_freq = 1.0 / new_base ** ( + torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd + ) + else: + inv_freq = self.inv_freq.float().to(device) + t = torch.arange(seq_len, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached[:, :seq_len].to(dtype=dtype), self._sin_cached[:, :seq_len].to(dtype=dtype) + + +def apply_rotary_emb(x, cos, sin, rope_dims=0): + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * -sin + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=True, + attn_out_gate=False, attn_out_gate_src="proj", gate_window=12, + gated_attn=False, gated_attn_init_std=0.01, + sparse_attn_gate=False, sparse_attn_gate_init_std=0.0, sparse_attn_gate_scale=1.0, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + if int(attn_out_gate) + int(gated_attn) + int(sparse_attn_gate) > 1: + raise ValueError( + "attn_out_gate, gated_attn, and sparse_attn_gate are mutually exclusive" + ) + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + self.q_gain = nn.Parameter( + torch.full((num_heads,), qk_gain_init, dtype=torch.float32) + ) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len, yarn=yarn) + self.use_xsa = False + # AttnOutGate (PR #1667 MarioPaerle): per-head multiplicative gate on attention + # output. CastedLinear so restore_fp32_params casts back to fp32 for GPTQ. + # _zero_init -> 2*sigmoid(0)=1 -> transparent at init. + self.attn_out_gate = attn_out_gate + self.attn_out_gate_src = attn_out_gate_src + self.gate_window = gate_window + if attn_out_gate: + self.attn_gate_proj = CastedLinear(gate_window, num_heads, bias=False) + self.attn_gate_proj._zero_init = True + # Gated Attention (arXiv:2505.06708, Qwen, NeurIPS 2025). Per-head sigmoid + # gate on SDPA output, BEFORE out_proj. Gate projection W_g: (num_heads, dim). + # Name "attn_gate_w" contains "attn_gate" substring so it matches + # CONTROL_TENSOR_NAME_PATTERNS and routes to the scalar AdamW group. + # fp32 Parameter -> restore_fp32_params path covers it via the ndim<2 OR + # name-pattern check (name matches "attn_gate"). Cast to x.dtype on use. + self.gated_attn = gated_attn + if gated_attn: + W = torch.empty(num_heads, dim, dtype=torch.float32) + nn.init.normal_(W, mean=0.0, std=gated_attn_init_std) + self.attn_gate_w = nn.Parameter(W) + # Sparse attention head-output gate (modded-nanogpt style). Keeps dense SDPA + # and only narrows the gate input to the first gate_window residual dims. + # W_g: (num_heads, gate_window). y_{t,h} <- sigmoid(scale * W_g_h @ x_t[:gate_window]) * y_{t,h}. + # Shares attn_gate_w name with dense GatedAttn so the quant routing + # (CONTROL_TENSOR_NAME_PATTERNS / attn_gate_w int8 passthrough) is unchanged. + self.sparse_attn_gate = sparse_attn_gate + self.sparse_attn_gate_scale = sparse_attn_gate_scale + if sparse_attn_gate: + W = torch.empty(num_heads, gate_window, dtype=torch.float32) + if sparse_attn_gate_init_std > 0: + nn.init.normal_(W, mean=0.0, std=sparse_attn_gate_init_std) + else: + nn.init.zeros_(W) + self.attn_gate_w = nn.Parameter(W) + + def _xsa_efficient(self, y, v): + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x, q_w, k_w, v_w, out_w, cu_seqlens=None, max_seqlen=0): + bsz, seqlen, dim = x.shape + # q_raw kept around as a tap point for attn_out_gate_src='q' (post-projection, + # pre-reshape, pre-RoPE). + q_raw = F.linear(x, q_w.to(x.dtype)) + q = q_raw.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if cu_seqlens is not None: + y = flash_attn_varlen_func( + q[0], + k[0], + v[0], + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + window_size=(-1, -1), + )[None] + else: + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + # AttnOutGate inlined (PR #1667). Inline + .contiguous() barrier so torch.compile + # fullgraph=True is happy (this avoids the @torch.compiler.disable trap that + # crashed gates v3). Per-head gate on (B,T,H,D) tensor: g shape [B,T,H], broadcast + # over D via [..., None]. zero-init weight -> 2*sigmoid(0)=1 -> transparent. + if self.attn_out_gate: + gate_src = q_raw if self.attn_out_gate_src == "q" else x + gate_in = gate_src[..., : self.gate_window].contiguous() + g = 2.0 * torch.sigmoid(self.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (arXiv:2505.06708 G1). Inline + .contiguous() barrier so + # torch.compile fullgraph=True is happy. Per-head gate on (B,T,H,D): g shape + # [B,T,H], broadcast over D via [..., None]. Paper: g = sigmoid(x @ W_g.T) + # where W_g: (H, dim). .to(x.dtype) on fp32 param before broadcast with bf16. + if self.gated_attn: + x_c = x.contiguous() + g = torch.sigmoid(F.linear(x_c, self.attn_gate_w.to(x.dtype))) + y = y * g[..., None] + # Sparse head-output gate: narrower (gate_window) input, same shape g as GatedAttn. + if self.sparse_attn_gate: + gate_in = x[..., : self.gate_window].contiguous() + g = torch.sigmoid( + self.sparse_attn_gate_scale + * F.linear(gate_in, self.attn_gate_w.to(x.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + self._last_proj_input = y.detach() if getattr(self, "_calib", False) else None + return F.linear(y, out_w.to(x.dtype)) + + +class MLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__() + self.use_fused = True + + def forward(self, x, up_w, down_w): + if self.training and self.use_fused: + return FusedLeakyReLUSquareMLP(x, up_w.to(x.dtype), down_w.to(x.dtype)) + hidden = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.5).square() + self._last_down_input = hidden.detach() if getattr(self, "_calib", False) else None + return F.linear(hidden, down_w.to(x.dtype)) + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + train_seq_len, + layer_idx=0, + ln_scale=False, + yarn=True, + attn_out_gate=False, + attn_out_gate_src="proj", + gate_window=12, + gated_attn=False, + gated_attn_init_std=0.01, + sparse_attn_gate=False, + sparse_attn_gate_init_std=0.0, + sparse_attn_gate_scale=1.0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len, yarn=yarn, + attn_out_gate=attn_out_gate, attn_out_gate_src=attn_out_gate_src, gate_window=gate_window, + gated_attn=gated_attn, gated_attn_init_std=gated_attn_init_std, + sparse_attn_gate=sparse_attn_gate, + sparse_attn_gate_init_std=sparse_attn_gate_init_std, + sparse_attn_gate_scale=sparse_attn_gate_scale, + ) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter( + torch.stack((torch.ones(dim), torch.zeros(dim))).float() + ) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=None, max_seqlen=0): + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn( + self.attn_norm(x_in) * self.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[ + None, None, : + ] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + return x_out + +class GPT(nn.Module): + def __init__(self, h): + super().__init__() + if h.logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.asym_logit_enabled = h.asym_logit_rescale + if self.asym_logit_enabled: + self.softcap_pos = nn.Parameter(torch.tensor(h.logit_softcap)) + self.softcap_neg = nn.Parameter(torch.tensor(h.logit_softcap)) + self.fused_ce_enabled = bool(h.fused_ce_enabled) + self.tok_emb = nn.Embedding(h.vocab_size, h.model_dim) + self.num_layers = h.num_layers + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + self.qo_bank = nn.Parameter(torch.empty(2 * h.num_layers, h.model_dim, h.model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * h.num_layers, kv_dim, h.model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(h.num_layers, hidden_dim, h.model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(h.num_layers, h.model_dim, hidden_dim)) + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.blocks = nn.ModuleList( + [ + Block( + h.model_dim, + h.num_heads, + h.num_kv_heads, + h.mlp_mult, + h.rope_base, + # Layer 1: per-layer QK-Gain init schedule. Use schedule[i] if provided, + # else uniform qk_gain_init. Schedule is initialization only — q_gain + # stays trainable and diverges from this starting point during training. + (h.qk_gain_schedule[i] if h.qk_gain_schedule and i < len(h.qk_gain_schedule) + else h.qk_gain_init), + h.train_seq_len, + layer_idx=i, + ln_scale=h.ln_scale, + yarn=h.rope_yarn, + attn_out_gate=h.attn_out_gate_enabled, + attn_out_gate_src=h.attn_out_gate_src, + gate_window=h.gate_window, + gated_attn=h.gated_attn_enabled, + gated_attn_init_std=h.gated_attn_init_std, + sparse_attn_gate=h.sparse_attn_gate_enabled, + sparse_attn_gate_init_std=h.sparse_attn_gate_init_std, + sparse_attn_gate_scale=h.sparse_attn_gate_scale, + ) + for i in range(h.num_layers) + ] + ) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary( + head_dim, + base=h.rope_base, + train_seq_len=h.train_seq_len, + rope_dims=h.rope_dims, + yarn=h.rope_yarn, + ) + self.final_norm = RMSNorm() + self.lm_head = ( + None + if h.tie_embeddings + else CastedLinear(h.model_dim, h.vocab_size, bias=False) + ) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + self.looping_active = False + if h.num_loops > 0: + loop_seg = list(range(h.loop_start, h.loop_end + 1)) + all_indices = list(range(h.loop_start)) + for _ in range(h.num_loops + 1): + all_indices.extend(loop_seg) + all_indices.extend(range(h.loop_end + 1, h.num_layers)) + num_enc = len(all_indices) // 2 + self.encoder_indices = all_indices[:num_enc] + self.decoder_indices = all_indices[num_enc:] + else: + self.encoder_indices = list(range(self.num_encoder_layers)) + self.decoder_indices = list(range(self.num_encoder_layers, h.num_layers)) + self.num_skip_weights = min( + len(self.encoder_indices), len(self.decoder_indices) + ) + self.skip_weights = nn.Parameter( + torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + self.skip_gates = ( + nn.Parameter( + torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32) + ) + if h.skip_gates_enabled + else None + ) + self.parallel_start_layer = h.parallel_start_layer + self.parallel_final_lane = h.parallel_final_lane.lower() + self.parallel_post_lambdas = nn.Parameter( + torch.ones(h.num_layers, 2, 2, dtype=torch.float32) + ) + self.parallel_resid_lambdas = nn.Parameter( + torch.full((h.num_layers, 2), 1.1, dtype=torch.float32) + ) + # SmearGate (PR #1667 / modded-nanogpt @classiclarryd): + # x_t <- x_t + lam * sigmoid(W * x_t[:gate_window]) * x_{t-1}. + # Per-token forward-1 smear of the embedding lane. W zero-init + lam=0 -> + # transparent at init. Uses CastedLinear so restore_fp32_params handles dtype. + self.smear_gate_enabled = h.smear_gate_enabled + if self.smear_gate_enabled: + self.smear_window = h.gate_window + self.smear_gate = CastedLinear(self.smear_window, 1, bias=False) + self.smear_gate._zero_init = True + self.smear_lambda = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + self._init_weights() + + def _init_weights(self): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) + nn.init.zeros_(self.qo_bank.data[n + i]) + self.qo_bank.data[n + i].mul_(proj_scale) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) + for i in range(n): + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) + nn.init.zeros_(self.mlp_down_bank.data[i]) + self.mlp_down_bank.data[i].mul_(proj_scale) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif ( + module.weight.ndim == 2 + and module.weight.shape[0] >= 64 + and module.weight.shape[1] >= 64 + ): + nn.init.orthogonal_(module.weight, gain=1.0) + + def _bank_weights(self, i): + n = self.num_layers + return ( + self.qo_bank[i], + self.kv_bank[i], + self.kv_bank[n + i], + self.qo_bank[n + i], + self.mlp_up_bank[i], + self.mlp_down_bank[i], + ) + + def _parallel_block( + self, block_idx, lane0, lane1, x0, + q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=None, max_seqlen=0, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + attn_out = block.attn( + block.attn_norm(attn_read) * block.ln_scale_factor, + q_w, k_w, v_w, out_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * block.mlp( + block.mlp_norm(mlp_read) * block.ln_scale_factor, up_w, down_w + ) + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + def _final_parallel_hidden(self, lane0, lane1): + if self.parallel_final_lane == "mlp": + return lane1 + if self.parallel_final_lane == "attn": + return lane0 + return 0.5 * (lane0 + lane1) + + def _forward_hidden(self, input_ids, cu_seqlens=None, max_seqlen=0): + """Run the encoder/decoder stack to the final RMSNorm; returns pre-projection hidden. + Shared by eval (softcap+projection via forward_logits) and train (fused CE path).""" + x = self.tok_emb(input_ids) + # SmearGate (PR #1667). Inline gate compute with .contiguous() on the slice fed + # to the projection so torch.compile fullgraph is happy. lam=0 + W=0 -> identity + # at init. This block runs unconditionally on the smear path; the cat keeps + # position 0 untouched so causality holds. + if self.smear_gate_enabled: + sl = self.smear_lambda.to(dtype=x.dtype) + gate_in = x[:, 1:, : self.smear_window].contiguous() + g = sl * torch.sigmoid(self.smear_gate(gate_in)) + if os.environ.get("SMEAR_GATE_BOS_FIX", "0") == "1": + not_bos = (input_ids[:, 1:] != BOS_ID).to(x.dtype).unsqueeze(-1) + x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1] * not_bos], dim=1) + else: + x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1]], dim=1) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else range(self.num_encoder_layers) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block( + i, lane0, lane1, x0, q_w, k_w, v_w, out_w, up_w, down_w, + cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self.blocks[i](x, x0, q_w, k_w, v_w, out_w, up_w, down_w, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + return x + + def _project_logits(self, hidden): + if self.tie_embeddings: + return F.linear(hidden, self.tok_emb.weight) + return self.lm_head(hidden) + + def _apply_asym_softcap(self, logits): + """Asymmetric softcap: independent pos/neg learned scalars (PR #1923). + Init: softcap_pos == softcap_neg == logit_softcap → identical to scalar at step 0.""" + sp = self.softcap_pos.to(logits.dtype) + sn = self.softcap_neg.to(logits.dtype) + return torch.where(logits >= 0, + sp * torch.tanh(logits / sp), + sn * torch.tanh(logits / sn)) + + def forward_logits(self, input_ids, cu_seqlens=None, max_seqlen=0): + hidden = self._forward_hidden(input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + logits_proj = self._project_logits(hidden) + if self.asym_logit_enabled: + return self._apply_asym_softcap(logits_proj) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids, target_ids, cu_seqlens=None, max_seqlen=0): + hidden = self._forward_hidden(input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + logits_proj = self._project_logits(hidden) + flat_targets = target_ids.reshape(-1) + # Fused softcapped-CE kernel (training path only). Applies softcap inside the + # Triton kernel; takes pre-softcap logits_proj. Non-fused path matches stock + # PR-1736 numerics exactly (softcap in fp32, then F.cross_entropy on fp32). + if self.fused_ce_enabled: + return softcapped_cross_entropy( + logits_proj.reshape(-1, logits_proj.size(-1)), + flat_targets, + self.logit_softcap, + reduction="mean", + ) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + flat_targets, + reduction="mean", + ) + + def forward_ttt(self, input_ids, target_ids, lora, hint_ids=None): + x = self.tok_emb(input_ids) + # SmearGate on the TTT path — same inline compute as forward_logits. + if self.smear_gate_enabled: + sl = self.smear_lambda.to(dtype=x.dtype) + gate_in = x[:, 1:, : self.smear_window].contiguous() + g = sl * torch.sigmoid(self.smear_gate(gate_in)) + if os.environ.get("SMEAR_GATE_BOS_FIX", "0") == "1": + not_bos = (input_ids[:, 1:] != BOS_ID).to(x.dtype).unsqueeze(-1) + x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1] * not_bos], dim=1) + else: + x = torch.cat([x[:, :1], x[:, 1:] + g * x[:, :-1]], dim=1) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips = [] + enc_iter = ( + self.encoder_indices + if self.looping_active + else list(range(self.num_encoder_layers)) + ) + dec_iter = ( + self.decoder_indices + if self.looping_active + else list( + range( + self.num_encoder_layers, + self.num_encoder_layers + self.num_decoder_layers, + ) + ) + ) + slot = 0 + for i in enc_iter: + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + skips.append(x) + psl = self.parallel_start_layer + lane0 = None + lane1 = None + for skip_idx, i in enumerate(dec_iter): + q_w, k_w, v_w, out_w, up_w, down_w = self._bank_weights(i) + if i >= psl and psl > 0: + if lane0 is None: + lane0 = x + lane1 = x + if skip_idx < self.num_skip_weights and skips: + skip = skips.pop() + w = self.skip_weights[skip_idx].to(dtype=lane0.dtype)[None, None, :] + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=lane0.dtype))[None, None, :] + lane0 = torch.lerp(w * skip, lane0, g) + else: + lane0 = lane0 + w * skip + lane0, lane1 = self._parallel_block_with_lora( + i, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ) + else: + if skip_idx < self.num_skip_weights and skips: + scaled_skip = ( + self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] + * skips.pop() + ) + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + x = self._block_with_lora(self.blocks[i], x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w) + slot += 1 + if lane0 is not None: + x = self._final_parallel_hidden(lane0, lane1) + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + lora.lm_head_lora(x) + if self.asym_logit_enabled: + logits = self._apply_asym_softcap(logits) + else: + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + bsz, sl, V = logits.shape + if hint_ids is None: + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none" + ).reshape(bsz, sl) + # PR #1145 tilt branch: return (per_tok_loss, log_q_hint) for scoring. + # TTT training backward uses requires_grad path (plain log_softmax); scoring + # uses Triton fused kernel (no autograd needed, saves memory + time). + if logits.requires_grad: + ls = F.log_softmax(logits.float(), dim=-1) + log_p_y = ls.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1) + log_q_h = ls.gather(-1, hint_ids.clamp(min=0).unsqueeze(-1)).squeeze(-1) + return -log_p_y, log_q_h + log_p_y, log_q_h = fused_log_softmax_dual_gather( + logits, target_ids, hint_ids.clamp(min=0) + ) + return -log_p_y, log_q_h + + def _block_with_lora(self, block, x, x0, lora, slot, q_w, k_w, v_w, out_w, up_w, down_w): + mix = block.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = block.attn_norm(x_in) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + # Keep raw Q for AttnOutGate src='q' (matches forward path semantics). + q_raw = F.linear(n, q_w.to(n.dtype)) + lora.q_loras[slot](n) + q = q_raw.reshape(bsz, seqlen, attn.num_heads, attn.head_dim) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = (F.linear(n, v_w.to(n.dtype)) + lora.v_loras[slot](n)).reshape( + bsz, seqlen, attn.num_kv_heads, attn.head_dim + ) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + # AttnOutGate (TTT path) — inline + .contiguous() barrier, same as the eval path. + if attn.attn_out_gate: + gate_src = q_raw if attn.attn_out_gate_src == "q" else n + gate_in = gate_src[..., : attn.gate_window].contiguous() + g = 2.0 * torch.sigmoid(attn.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (TTT path). Gate input is n (post-norm block input), same + # as eval path. .to(n.dtype) on fp32 param before bf16 broadcast. + if attn.gated_attn: + n_c = n.contiguous() + g = torch.sigmoid(F.linear(n_c, attn.attn_gate_w.to(n.dtype))) + y = y * g[..., None] + # Sparse attention head-output gate (TTT path) — must match the eval path in + # forward() exactly, else training (which applied the gate) and TTT eval (which + # skipped it) produce mismatched representations and catastrophic BPB regression. + if attn.sparse_attn_gate: + gate_in = n[..., : attn.gate_window].contiguous() + g = torch.sigmoid( + attn.sparse_attn_gate_scale + * F.linear(gate_in, attn.attn_gate_w.to(n.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + x_out = x_in + block.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + mlp_n = block.mlp_norm(x_out) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + x_out = x_out + block.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * mlp_out + return x_out + + def _parallel_block_with_lora( + self, block_idx, lane0, lane1, x0, lora, slot, + q_w, k_w, v_w, out_w, up_w, down_w, + ): + block = self.blocks[block_idx] + mix = block.resid_mix.to(dtype=lane0.dtype) + attn_read = mix[0][None, None, :] * lane0 + mix[1][None, None, :] * x0 + n = block.attn_norm(attn_read) * block.ln_scale_factor + attn = block.attn + bsz, seqlen, dim = n.shape + q_raw = F.linear(n, q_w.to(n.dtype)) + lora.q_loras[slot](n) + q = q_raw.reshape(bsz, seqlen, attn.num_heads, attn.head_dim) + k = F.linear(n, k_w.to(n.dtype)) + if lora.k_loras is not None: + k = k + lora.k_loras[slot](n) + k = k.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = (F.linear(n, v_w.to(n.dtype)) + lora.v_loras[slot](n)).reshape( + bsz, seqlen, attn.num_kv_heads, attn.head_dim + ) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, n.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, attn.rope_dims) + k = apply_rotary_emb(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa_efficient(y, v) + # AttnOutGate (TTT parallel path) — inline + .contiguous() barrier. + if attn.attn_out_gate: + gate_src = q_raw if attn.attn_out_gate_src == "q" else n + gate_in = gate_src[..., : attn.gate_window].contiguous() + g = 2.0 * torch.sigmoid(attn.attn_gate_proj(gate_in)) + y = y * g[..., None] + # Gated Attention (TTT parallel path). Gate input is n (post-norm block input). + if attn.gated_attn: + n_c = n.contiguous() + g = torch.sigmoid(F.linear(n_c, attn.attn_gate_w.to(n.dtype))) + y = y * g[..., None] + # Sparse attention head-output gate (TTT parallel path) — must match the + # eval path in forward() to keep train/eval semantics in sync. + if attn.sparse_attn_gate: + gate_in = n[..., : attn.gate_window].contiguous() + g = torch.sigmoid( + attn.sparse_attn_gate_scale + * F.linear(gate_in, attn.attn_gate_w.to(n.dtype)) + ) + y = y * g[..., None] + y = y.reshape(bsz, seqlen, dim) + attn_out = F.linear(y, out_w.to(n.dtype)) + if lora.o_loras is not None: + attn_out = attn_out + lora.o_loras[slot](n) + attn_out = block.attn_scale.to(dtype=attn_out.dtype)[None, None, :] * attn_out + mlp_read = lane1 + mlp_n = block.mlp_norm(mlp_read) * block.ln_scale_factor + mlp_out = block.mlp(mlp_n, up_w, down_w) + if lora.mlp_loras is not None: + mlp_out = mlp_out + lora.mlp_loras[slot](mlp_n) + mlp_out = block.mlp_scale.to(dtype=lane1.dtype)[None, None, :] * mlp_out + attn_resid = self.parallel_resid_lambdas[block_idx, 0].to(dtype=lane0.dtype) + attn_post = self.parallel_post_lambdas[block_idx, 0].to(dtype=lane0.dtype) + mlp_resid = self.parallel_resid_lambdas[block_idx, 1].to(dtype=lane0.dtype) + mlp_post = self.parallel_post_lambdas[block_idx, 1].to(dtype=lane0.dtype) + lane0 = attn_resid * lane0 + attn_post[0] * attn_out + mlp_post[0] * mlp_out + lane1 = mlp_resid * lane1 + attn_post[1] * attn_out + mlp_post[1] * mlp_out + return lane0, lane1 + + +class BatchedLinearLoRA(nn.Module): + # PR-1767: rank-scaled output (alpha/rank), like standard LoRA. Decouples + # effective magnitude from rank so changing rank does not change LR scale. + _ALPHA = float(os.environ.get("TTT_LORA_ALPHA", "144")) + # PR-1767: optionally keep A warm across per-doc resets (only B is zeroed). + # Accumulates useful feature directions across documents within a TTT phase. + _WARM_START_A = bool(int(os.environ.get("TTT_WARM_START_A", "1"))) + + def __init__(self, bsz, in_features, out_features, rank): + super().__init__() + self._bound = 1.0 / math.sqrt(in_features) + self._scale = self._ALPHA / rank + self.A = nn.Parameter( + torch.empty(bsz, rank, in_features).uniform_(-self._bound, self._bound) + ) + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) + + def reset(self): + with torch.no_grad(): + if not self._WARM_START_A: + self.A.uniform_(-self._bound, self._bound) + self.B.zero_() + + def forward(self, x): + return ((x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2)) * self._scale + + +class BatchedTTTLoRA(nn.Module): + def __init__(self, bsz, model, rank, k_lora=True, mlp_lora=True, o_lora=True): + super().__init__() + self.bsz = bsz + dim = model.qo_bank.shape[-1] + vocab = model.tok_emb.num_embeddings + if getattr(model, "looping_active", False): + num_slots = len(model.encoder_indices) + len(model.decoder_indices) + else: + num_slots = len(model.blocks) + kv_dim = model.blocks[0].attn.num_kv_heads * ( + dim // model.blocks[0].attn.num_heads + ) + embed_dim = model.tok_emb.embedding_dim + self.lm_head_lora = BatchedLinearLoRA(bsz, embed_dim, vocab, rank) + self.q_loras = nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + self.v_loras = nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + self.k_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, kv_dim, rank) for _ in range(num_slots)] + ) + if k_lora + else None + ) + self.mlp_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if mlp_lora + else None + ) + self.o_loras = ( + nn.ModuleList( + [BatchedLinearLoRA(bsz, dim, dim, rank) for _ in range(num_slots)] + ) + if o_lora + else None + ) + + def reset(self): + with torch.no_grad(): + self.lm_head_lora.reset() + for loras in [self.q_loras, self.v_loras, self.k_loras, + self.mlp_loras, self.o_loras]: + if loras is not None: + for lora in loras: + lora.reset() + + +# Polar Express per-iteration minimax Newton-Schulz coefficients (PR #1344). +# Replaces the fixed (3.4445, -4.775, 2.0315) coefficients of stock Muon. +# Applied at backend_steps=5 — taking more than 5 iterations from this list +# falls back to the final (converged) tuple via the slice guard below. +_PE_COEFFS = ( + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323), +) + + +@torch.compile +def zeropower_via_newtonschulz5(G, steps=10, eps=1e-07): + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + coeffs = _PE_COEFFS[:steps] if steps <= len(_PE_COEFFS) else _PE_COEFFS + for a, b, c in coeffs: + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + + +def _apply_huber_wd(w, lr, wd, delta): + """Layer 3: Huber weight decay in-place. + Gradient: w (|w|<=delta), delta*sign(w) (|w|>delta). + Transitions from quadratic (L2-like) to linear (L1-like) at |w|=delta. + Bounds the decay rate on outlier weights, preventing GPTQ-damaging over-suppression. + """ + wf = w.float() + abs_w = wf.abs() + grad = torch.where(abs_w <= delta, wf, delta * wf.sign()) + w.add_(grad.to(w.dtype), alpha=-lr * wd) + + +class Muon(torch.optim.Optimizer): + def __init__( + self, + params, + lr, + momentum, + backend_steps, + nesterov=True, + weight_decay=0.0, + row_normalize=False, + ): + super().__init__( + params, + dict( + lr=lr, + momentum=momentum, + backend_steps=backend_steps, + nesterov=nesterov, + weight_decay=weight_decay, + row_normalize=row_normalize, + ), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + "p": p, + "B": B, + "padded_grad": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "shard": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "shard_mom": torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + "full_update": torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + "scale": max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + self._bank_meta.sort(key=lambda m: -m["p"].numel()) + self._built = True + + def launch_reduce_scatters(self): + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m["p"] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m["padded_grad"] + pg[: m["B"]].copy_(p.grad.bfloat16()) + if pg.shape[0] > m["B"]: + pg[m["B"] :].zero_() + fut = dist.reduce_scatter_tensor( + m["shard"], pg, op=dist.ReduceOp.AVG, async_op=True + ) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + if not self._built: + self._build() + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + row_normalize = group.get("row_normalize", False) + # Layer 3: read AdamHD Huber WD flags from param group + _huber_wd = group.get("huber_wd", False) + _huber_delta = group.get("huber_delta", 0.1) + prev_ag_handle = None + prev_m = None + sharded = self._distributed and hasattr(self, "_rs_futures") + for idx, m in enumerate(self._bank_meta): + p = m["p"] + if p.grad is None: + continue + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + if wd > 0.0: + if _huber_wd: + _apply_huber_wd(pp.data, lr, wd, _huber_delta) + else: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m["scale"]) + if sharded and self._rs_futures[idx] is not None: + self._rs_futures[idx].wait() + g = m["shard"] + buf = m["shard_mom"] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + if row_normalize: + rn = update.float().norm(dim=-1, keepdim=True).clamp_min(1e-07) + update = update / rn.to(update.dtype) + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m["full_update"], update, async_op=True + ) + prev_m = m + else: + if wd > 0.0: + if _huber_wd: + _apply_huber_wd(p.data, lr, wd, _huber_delta) + else: + p.data.mul_(1.0 - lr * wd) + p.add_(update.to(dtype=p.dtype), alpha=-lr * m["scale"]) + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m["p"] + upd = prev_m["full_update"][: prev_m["B"]] + if wd > 0.0: + if _huber_wd: + _apply_huber_wd(pp.data, lr, wd, _huber_delta) + else: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m["scale"]) + if hasattr(self, "_rs_futures"): + del self._rs_futures + return loss + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,parallel_post_lambdas,parallel_resid_lambdas,attn_gate_proj,attn_gate_w,smear_gate,smear_lambda", + ).split(",") + if pattern +) + + +PACKED_REPLICATED_GRAD_MAX_NUMEL = 1 << 15 + + +class Optimizers: + def __init__(self, h, base_model): + matrix_params = [ + base_model.qo_bank, + base_model.kv_bank, + base_model.mlp_up_bank, + base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for (name, p) in block_named_params + if p.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + if base_model.parallel_post_lambdas is not None: + scalar_params.append(base_model.parallel_post_lambdas) + if base_model.parallel_resid_lambdas is not None: + scalar_params.append(base_model.parallel_resid_lambdas) + # SmearGate params live on GPT root (not in .blocks), so add them by hand. + # Both are tiny (gate_window scalars + 1 lambda). Optimized via scalar Adam. + if getattr(base_model, "smear_gate_enabled", False): + scalar_params.append(base_model.smear_gate.weight) + scalar_params.append(base_model.smear_lambda) + token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr + tok_params = [ + {"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr} + ] + self.optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.embed_wd, + fused=True, + ) + self.optimizer_muon = Muon( + matrix_params, + lr=h.matrix_lr, + momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, + weight_decay=h.muon_wd, + row_normalize=h.muon_row_normalize, + ) + for group in self.optimizer_muon.param_groups: + group["base_lr"] = h.matrix_lr + # Layer 3: AdamHD Huber WD flags injected into each param group. + group["huber_wd"] = bool(getattr(h, "muon_huber_wd", False)) + group["huber_delta"] = float(getattr(h, "muon_huber_delta", 0.1)) + self.optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.adam_wd, + fused=True, + ) + self.optimizers = [ + self.optimizer_tok, + self.optimizer_muon, + self.optimizer_scalar, + ] + self.replicated_params = list(tok_params[0]["params"]) + self.replicated_params.extend(scalar_params) + self.replicated_large_params = [] + self.replicated_packed_params = [] + for p in self.replicated_params: + if p.numel() <= PACKED_REPLICATED_GRAD_MAX_NUMEL: + self.replicated_packed_params.append(p) + else: + self.replicated_large_params.append(p) + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self): + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def _all_reduce_packed_grads(self): + grads_by_key = collections.defaultdict(list) + for p in self.replicated_packed_params: + if p.grad is not None: + grads_by_key[(p.grad.device, p.grad.dtype)].append(p.grad) + for grads in grads_by_key.values(): + flat = torch.empty( + sum(g.numel() for g in grads), + device=grads[0].device, + dtype=grads[0].dtype, + ) + offset = 0 + for g in grads: + n = g.numel() + flat[offset : offset + n].copy_(g.contiguous().view(-1)) + offset += n + dist.all_reduce(flat, op=dist.ReduceOp.AVG) + offset = 0 + for g in grads: + n = g.numel() + g.copy_(flat[offset : offset + n].view_as(g)) + offset += n + + def step(self, distributed=False): + self.optimizer_muon.launch_reduce_scatters() + if distributed: + reduce_handles = [ + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, async_op=True) + for p in self.replicated_large_params + if p.grad is not None + ] + self._all_reduce_packed_grads() + for handle in reduce_handles: + handle.wait() + self.optimizer_tok.step() + self.optimizer_scalar.step() + self.optimizer_muon.step() + self.zero_grad_all() + + +def restore_fp32_params(model): + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if ( + param.ndim < 2 + or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ) and param.dtype != torch.float32: + param.data = param.data.float() + if hasattr(model, "qo_bank") and model.qo_bank is not None: + model.qo_bank.data = model.qo_bank.data.float() + model.kv_bank.data = model.kv_bank.data.float() + model.mlp_up_bank.data = model.mlp_up_bank.data.float() + model.mlp_down_bank.data = model.mlp_down_bank.data.float() + + +def collect_hessians(model, train_loader, h, device, n_calibration_batches=64): + hessians = {} + hooks = [] + for i, block in enumerate(model.blocks): + block.attn._calib = True + block.mlp._calib = True + block.mlp.use_fused = False + + def make_attn_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + for suffix in ["c_q", "c_k", "c_v"]: + name = f"blocks.{layer_idx}.attn.{suffix}.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + y = module._last_proj_input + if y is not None: + y = y.float() + if y.ndim == 3: + y = y.reshape(-1, y.shape[-1]) + name = f"blocks.{layer_idx}.attn.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + y.shape[1], y.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(y.T, y) + return hook_fn + + def make_mlp_hook(layer_idx): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + name = f"blocks.{layer_idx}.mlp.fc.weight" + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + h_act = module._last_down_input + if h_act is not None: + h_act = h_act.float() + if h_act.ndim == 3: + h_act = h_act.reshape(-1, h_act.shape[-1]) + name = f"blocks.{layer_idx}.mlp.proj.weight" + if name not in hessians: + hessians[name] = torch.zeros( + h_act.shape[1], h_act.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(h_act.T, h_act) + return hook_fn + + for i, block in enumerate(model.blocks): + hooks.append(block.attn.register_forward_hook(make_attn_hook(i))) + hooks.append(block.mlp.register_forward_hook(make_mlp_hook(i))) + + # Hessian hooks for embedding factorization projection layers + def make_linear_input_hook(weight_name): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if weight_name not in hessians: + hessians[weight_name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[weight_name].addmm_(x.T, x) + return hook_fn + + if model.tie_embeddings: + hook_module = model.final_norm + + def make_output_hook(name): + def hook_fn(module, inp, out): + x = out.detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + return hook_fn + + hooks.append( + hook_module.register_forward_hook(make_output_hook("tok_emb.weight")) + ) + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches): + x, _ = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + model.forward_logits(x) + for hook in hooks: + hook.remove() + for i, block in enumerate(model.blocks): + block.attn._calib = False + block.mlp._calib = False + block.mlp.use_fused = True + for name in hessians: + hessians[name] = hessians[name].cpu() / n_calibration_batches + return hessians + + +def gptq_quantize_weight(w, H, clip_sigmas=3.0, clip_range=63, block_size=128): + W_orig = w.float().clone() + rows, cols = W_orig.shape + H = H.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * H.diag().mean() + H.diagonal().add_(damp) + perm = torch.argsort(H.diag(), descending=True) + invperm = torch.argsort(perm) + W_perm = W_orig[:, perm].clone() + W_perm[:, dead[perm]] = 0 + H = H[perm][:, perm] + Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + row_std = W_orig.std(dim=1) + s = (clip_sigmas * row_std / clip_range).clamp_min(1e-10).to(torch.float16) + sf = s.float() + Q = torch.zeros(rows, cols, dtype=torch.int8) + W_work = W_perm.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + W_block = W_work[:, i1:i2].clone() + Hinv_block = Hinv[i1:i2, i1:i2] + Err = torch.zeros(rows, i2 - i1) + for j in range(i2 - i1): + w_col = W_block[:, j] + d = Hinv_block[j, j] + q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) + Q[:, i1 + j] = q_col.to(torch.int8) + err = (w_col - q_col.float() * sf) / d + Err[:, j] = err + W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) + if i2 < cols: + W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] + return Q[:, invperm], s + + +def _quantize_gate_int8_row(w): + # Symmetric int8-per-row quantization for small gate tensors. w shape + # (R, C) -> (R,) scales in fp16, int8 values in [-127, 127]. Single scale + # per row keeps accuracy high while halving storage vs fp16. + W = w.float().contiguous() + row_max = W.abs().amax(dim=1).clamp_min(1e-10) + s = (row_max / 127.0).to(torch.float16) + sf = s.float().view(-1, 1) + q = torch.clamp(torch.round(W / sf), -127, 127).to(torch.int8) + return q, s + + +def _lqer_pack(A, B, bits): + rng = 2 ** (bits - 1) - 1 + sA = (A.abs().amax(dim=1).clamp_min(1e-10) / rng).to(torch.float16) + sB = (B.abs().amax(dim=1).clamp_min(1e-10) / rng).to(torch.float16) + qA = torch.clamp(torch.round(A / sA.float().view(-1, 1)), -rng, rng).to(torch.int8) + qB = torch.clamp(torch.round(B / sB.float().view(-1, 1)), -rng, rng).to(torch.int8) + return qA, sA, qB, sB + + +def _lqer_pack_asym(A, B, g=64): + # A: INT2 per-matrix scalar (signed [-2,1], scale = |A|max/1.5). + sA = (A.abs().amax().clamp_min(1e-10) / 1.5).to(torch.float16) + qA = torch.clamp(torch.round(A / sA.float()), -2, 1).to(torch.int8) + # B: INT4 groupwise g over flattened B (signed [-8,7], per-group scale). + Bf = B.reshape(-1, g) + Bmax = Bf.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10) + sB = (Bmax / 7.5).to(torch.float16).reshape(-1) + qB = torch.clamp(torch.round(Bf / sB.float().reshape(-1, 1)), -8, 7).to( + torch.int8 + ).reshape(B.shape) + return qA, sA, qB, sB + + +def gptq_mixed_quantize(state_dict, hessians, h): + result = {} + meta = {} + quant_gate = bool(getattr(h, "gated_attn_quant_gate", False)) + lqer_on = bool(getattr(h, "lqer_enabled", False)) + lqer_cands = {} + for (name, tensor) in state_dict.items(): + t = tensor.detach().cpu().contiguous() + # Dedicated int8-per-row path for attn_gate_w (bypasses both GPTQ and + # fp16 passthrough). Applied BEFORE the numel<=65536 passthrough check + # so the gate tensor is routed here instead of to fp16. + if ( + quant_gate + and t.is_floating_point() + and t.ndim == 2 + and name.endswith(".attn_gate_w") + # Dense GatedAttn: (num_heads, dim) = (8, 512) = 4096. + # Sparse gate: (num_heads, gate_window) = (8, 12) = 96. + # Both need int8-per-row routing; the 1024 lower bound in stock + # PR-1736 presumed dense-only. Widen to catch both. + and 32 <= t.numel() <= 8192 + ): + gq, gs = _quantize_gate_int8_row(t) + result[name + ".gq"] = gq + result[name + ".gs"] = gs + meta[name] = "gate_int8_row" + continue + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough (float16)" + continue + if "tok_emb" in name: + cs = h.embed_clip_sigmas + elif ".mlp." in name: + cs = h.mlp_clip_sigmas + elif ".attn." in name: + cs = h.attn_clip_sigmas + else: + cs = h.matrix_clip_sigmas + bits = h.embed_bits if "tok_emb" in name else h.matrix_bits + clip_range = 2 ** (bits - 1) - 1 + ret = gptq_quantize_weight( + t, hessians[name], clip_sigmas=cs, clip_range=clip_range + ) + q, s = ret + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = f"gptq (int{bits})" + if lqer_on: + W_q = q.float() * s.float().view(-1, 1) + E = t.float() - W_q + lqer_cands[name] = (E, float(E.norm())) + if lqer_on and lqer_cands: + top = sorted(lqer_cands.items(), key=lambda kv: -kv[1][1])[: h.lqer_top_k] + asym_on = bool(getattr(h, "lqer_asym_enabled", False)) + asym_g = int(getattr(h, "lqer_asym_group", 64)) + for (name, (E, _)) in top: + U, S, Vh = torch.linalg.svd(E, full_matrices=False) + r = min(h.lqer_rank, S.numel()) + A = (U[:, :r] * S[:r]).contiguous() + B = Vh[:r, :].contiguous() + if asym_on and B.numel() % asym_g == 0: + qA, sA, qB, sB = _lqer_pack_asym(A, B, asym_g) + result[name + ".lqA_a"] = qA + result[name + ".lqAs_a"] = sA + result[name + ".lqB_a"] = qB + result[name + ".lqBs_a"] = sB + meta[name] = meta[name] + "+lqer_asym" + else: + qA, sA, qB, sB = _lqer_pack(A, B, h.lqer_factor_bits) + result[name + ".lqA"] = qA + result[name + ".lqAs"] = sA + result[name + ".lqB"] = qB + result[name + ".lqBs"] = sB + meta[name] = meta[name] + "+lqer" + categories = collections.defaultdict(set) + for (name, cat) in meta.items(): + short = re.sub("\\.\\d+$", "", re.sub("blocks\\.\\d+", "blocks", name)) + categories[cat].add(short) + log("Quantized weights:") + for cat in sorted(categories): + log(f" {cat}: {', '.join(sorted(categories[cat]))}") + return result, meta + +def dequantize_mixed(result, meta, template_sd): + out = {} + for (name, orig) in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if "passthrough" in info: + t = result[name] + if t.dtype == torch.float16 and orig_dtype in ( + torch.float32, + torch.bfloat16, + ): + t = t.to(orig_dtype) + out[name] = t + continue + if info == "gate_int8_row": + gq = result[name + ".gq"] + gs = result[name + ".gs"] + out[name] = (gq.float() * gs.float().view(-1, 1)).to(orig_dtype) + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + W = q.float() * s.float().view(q.shape[0], *[1] * (q.ndim - 1)) + else: + W = q.float() * float(s.item()) + if "lqer_asym" in info: + qA_t = result[name + ".lqA_a"] + sA_t = result[name + ".lqAs_a"] + qB_t = result[name + ".lqB_a"] + sB_t = result[name + ".lqBs_a"] + qA = qA_t.float() * float(sA_t) + g_sz = qB_t.numel() // sB_t.numel() + qB = (qB_t.reshape(-1, g_sz).float() * sB_t.float().view(-1, 1)).reshape( + qB_t.shape + ) + W = W + qA @ qB + elif "lqer" in info: + qA = result[name + ".lqA"].float() * result[name + ".lqAs"].float().view(-1, 1) + qB = result[name + ".lqB"].float() * result[name + ".lqBs"].float().view(-1, 1) + W = W + qA @ qB + out[name] = W.to(orig_dtype) + return out + + +_BSHF_MAGIC = b"BSHF" + + +def _byte_shuffle(data, stride=2): + if stride <= 1 or len(data) < stride: + return data + src = np.frombuffer(data, dtype=np.uint8) + n = len(src) + out = np.empty(n, dtype=np.uint8) + dest_off = 0 + for pos in range(stride): + chunk = src[pos::stride] + out[dest_off : dest_off + len(chunk)] = chunk + dest_off += len(chunk) + return _BSHF_MAGIC + bytes([stride]) + out.tobytes() + + +def _byte_unshuffle(data): + if len(data) < 5 or data[:4] != _BSHF_MAGIC: + return data + stride = data[4] + if stride < 2: + return data[5:] + payload = np.frombuffer(data, dtype=np.uint8, offset=5) + n = len(payload) + out = np.empty(n, dtype=np.uint8) + src_off = 0 + for pos in range(stride): + chunk_len = n // stride + (1 if pos < n % stride else 0) + out[pos::stride][:chunk_len] = payload[src_off : src_off + chunk_len] + src_off += chunk_len + return out.tobytes() + + +def _compress(data, compressor): + data = _byte_shuffle(data) + if compressor == "lzma": + return lzma.compress(data, preset=6) + elif compressor == "brotli": + import brotli + + return brotli.compress(data, quality=11) + raise ValueError(f"Unknown compressor: {compressor!r}") + + +def _decompress(data, compressor): + if compressor == "lzma": + raw = lzma.decompress(data) + elif compressor == "brotli": + import brotli + + raw = brotli.decompress(data) + else: + raise ValueError(f"Unknown compressor: {compressor!r}") + raw = _byte_unshuffle(raw) + return raw + + +# --------------------------------------------------------------------------- +# COMPRESSOR=pergroup — role-bucketed lrzip/ZPAQ compression +# +# Splits the quantized state dict into two sections before serialization: +# 1. Q section – the large int8 GPTQ weight tensors (.weight.q keys). +# Each tensor's rows are sorted by L1 nearest-neighbour similarity so +# adjacent rows are numerically close, then transposed for better run- +# length regularity. All sorted+transposed blobs are concatenated and +# compressed with lrzip ZPAQ (-z -L 9). If lrzip is absent the section +# falls back to brotli automatically — never crashes. +# 2. Remainder – scales, LQER factors, passthrough tensors, quant_meta. +# Serialized with torch.save + byte_shuffle + brotli-11 (same as the +# default brotli path). +# +# Frame format (little-endian): +# [4B] magic b"PGRP" +# [4B] uint32 version = 2 +# [4B] uint32 n_q_tensors +# Per Q tensor (sorted by name): +# [2B] uint16 name_len +# [name_len B] name (UTF-8) +# [4B] uint32 rows (original shape[0] before sort+transpose) +# [4B] uint32 cols (original shape[1] before sort+transpose) +# [rows*2 B] uint16 row-permutation indices +# [1B] uint8 q_method (0 = brotli fallback, 1 = lrzip) +# [4B] uint32 q_data_size +# [q_data_size B] compressed Q blob +# [4B] uint32 remainder_size +# [remainder_size B] brotli-compressed remainder +# --------------------------------------------------------------------------- + +import struct as _struct + +_PGRP_MAGIC = b"PGRP" +_PGRP_VERSION = 2 + +# Only the large int8 weight tensors go through the pergroup path. +_PGRP_Q_SUFFIXES = ( + ".mlp.fc.weight.q", + ".mlp.proj.weight.q", + ".attn.c_q.weight.q", + ".attn.proj.weight.q", + ".attn.c_k.weight.q", + ".attn.c_v.weight.q", + "tok_emb.weight.q", +) + + +def _similarity_sort_l1(W): + """Greedy L1 nearest-neighbour row sort. Returns uint16 permutation array. + + O(n²·cols) numpy – takes ~3-6 s on the largest tensors (2048 rows). + """ + n = W.shape[0] + if n <= 1: + return np.arange(n, dtype=np.uint16) + W16 = W.astype(np.int16) + used = np.zeros(n, dtype=bool) + order = np.empty(n, dtype=np.int32) + order[0] = 0 + used[0] = True + for i in range(1, n): + d = np.abs(W16 - W16[order[i - 1]]).sum(axis=1) + d[used] = 2 ** 30 + nxt = int(d.argmin()) + order[i] = nxt + used[nxt] = True + return order.astype(np.uint16) + + +def _lrzip_compress_bytes(data): + """Compress bytes with lrzip ZPAQ via temp files. Returns bytes or None.""" + import tempfile + in_fd, in_path = tempfile.mkstemp(suffix=".bin") + out_path = in_path + ".lrz" + try: + os.write(in_fd, data) + os.close(in_fd) + in_fd = -1 + r = subprocess.run( + ["lrzip", "-z", "-L", "9", "-o", out_path, in_path], + capture_output=True, timeout=300, + ) + if r.returncode == 0 and os.path.exists(out_path): + with open(out_path, "rb") as fh: + return fh.read() + log(f"pergroup:lrzip exit {r.returncode}: {r.stderr.decode()[:200]}") + except FileNotFoundError: + log("pergroup:lrzip not found — falling back to brotli for Q section") + except subprocess.TimeoutExpired: + log("pergroup:lrzip timed out — falling back to brotli for Q section") + except Exception as e: + log(f"pergroup:lrzip error ({e}) — falling back to brotli for Q section") + finally: + if in_fd != -1: + try: os.close(in_fd) + except OSError: pass + for p in (in_path, out_path): + try: os.unlink(p) + except OSError: pass + return None + + +def _lrzip_decompress_bytes(data): + """Decompress lrzip ZPAQ bytes via temp files.""" + import tempfile + lrz_fd, lrz_path = tempfile.mkstemp(suffix=".lrz") + # lrzip strips .lrz → output name is lrz_path[:-4] + out_path = lrz_path[:-4] + try: + os.write(lrz_fd, data) + os.close(lrz_fd) + lrz_fd = -1 + r = subprocess.run( + ["lrzip", "-d", "-o", out_path, lrz_path], + capture_output=True, timeout=300, + ) + if r.returncode != 0: + raise RuntimeError(f"lrzip -d failed (exit {r.returncode}): {r.stderr.decode()[:400]}") + with open(out_path, "rb") as fh: + return fh.read() + finally: + if lrz_fd != -1: + try: os.close(lrz_fd) + except OSError: pass + # lrzip -d removes the input .lrz by default; OSError caught if already gone + for p in (lrz_path, out_path): + try: os.unlink(p) + except OSError: pass + + +def _pack_pergroup(quant_result, quant_meta): + """Serialize quantized state dict using the PGRP frame format. + + Q tensors → similarity-sorted, transposed, lrzip ZPAQ (brotli fallback). + Everything else → torch.save + byte_shuffle + brotli-11. + """ + import brotli + + # --- split Q tensors from remainder --- + q_items = {} # name → int8 numpy array + remainder = {} # name → tensor (scales, LQER, passthrough) + for name, t in quant_result.items(): + if any(name.endswith(sfx) for sfx in _PGRP_Q_SUFFIXES): + q_items[name] = (t.numpy() if hasattr(t, "numpy") else np.asarray(t)) + else: + remainder[name] = t + + q_names = sorted(q_items.keys()) + + # --- similarity sort + transpose per Q tensor --- + q_perms = {} # name → uint16 perm array + q_shapes = {} # name → (rows, cols) + q_blobs = [] # sorted+transposed int8 bytes, concatenated later + + t_sort = time.perf_counter() + for name in q_names: + W = q_items[name] + assert W.ndim == 2, f"pergroup: Q tensor {name} is not 2D: {W.shape}" + rows, cols = W.shape + q_shapes[name] = (rows, cols) + perm = _similarity_sort_l1(W) + q_perms[name] = perm + # sort rows then transpose → (cols, rows); adjacent values more similar + W_st = W[perm.astype(np.int32)].T.astype(np.int8) + q_blobs.append(W_st.tobytes()) + log(f"pergroup:similarity sort done in {time.perf_counter()-t_sort:.1f}s ({len(q_names)} tensors)") + + q_data_raw = b"".join(q_blobs) + + # --- compress Q section (lrzip ZPAQ, brotli fallback) --- + q_compressed = _lrzip_compress_bytes(q_data_raw) + if q_compressed is not None: + q_method = 1 # lrzip + else: + q_compressed = brotli.compress(q_data_raw, quality=11) + q_method = 0 # brotli fallback + log( + f"pergroup:Q {len(q_data_raw)} raw → {len(q_compressed)} " + f"({'lrzip' if q_method else 'brotli'}) " + f"({100*len(q_compressed)/max(len(q_data_raw),1):.1f}%)" + ) + + # --- remainder section: torch.save + byte_shuffle + brotli --- + rem_buf = io.BytesIO() + torch.save({"w": remainder, "m": quant_meta}, rem_buf) + rem_compressed = brotli.compress(_byte_shuffle(rem_buf.getvalue()), quality=11) + log(f"pergroup:remainder {rem_buf.tell()} raw → {len(rem_compressed)} brotli") + + # --- assemble PGRP frame --- + out = io.BytesIO() + out.write(_PGRP_MAGIC) + out.write(_struct.pack("= 1, f"OptRot: W.shape[0]={n} must be power of 2" + Y = W.clone() + h_step = 1 + while h_step < n: + # Butterfly: view as (n_blocks, 2, h_step, ...) so dim-1 selects first/second half + # of each 2*h_step block. Pairs element i with element i+h_step within each block. + n_blocks = n // (h_step * 2) + Y2 = Y.view(n_blocks, 2, h_step, *Y.shape[1:]) + a = Y2[:, 0, ...].clone() + b = Y2[:, 1, ...].clone() + Y2[:, 0, ...] = a + b + Y2[:, 1, ...] = a - b + Y = Y2.view(n, *Y.shape[1:]) + h_step *= 2 + return Y / math.sqrt(n) + + +def _optrot_apply(sd, h): + """Layer 2: Apply per-layer Hadamard rotation to (W_up, W_down) and (W_v, W_o) pairs. + + Rotation R = FWHT / sqrt(n) is orthogonal and self-inverse (R^2 = I). + For the MLP: W_up' = R @ W_up, W_down' = W_down @ R. Product preserved: + W_down' @ W_up' = W_down @ R @ R @ W_up = W_down @ W_up. + For attention V→O (per kv-head, per query-head group): W_v' = R @ W_v per kv-head block; + W_o' = W_o @ R per corresponding query-head column block. Product preserved. + + Only applied to layers with hidden_dim and head_dim that are powers of 2 (all layers + in the default 11L 512d 8H/4KV architecture: hidden=2048, head_dim=64, both OK). + """ + num_layers = h.num_layers + num_heads = h.num_heads + num_kv_heads = h.num_kv_heads + model_dim = h.model_dim + head_dim = model_dim // num_heads + kv_group = num_heads // num_kv_heads + + for i in range(num_layers): + # --- MLP rotation --- + W_up = sd[f"blocks.{i}.mlp.fc.weight"].float() # (hidden_dim, model_dim) + W_down = sd[f"blocks.{i}.mlp.proj.weight"].float() # (model_dim, hidden_dim) + hidden_dim = W_up.shape[0] + if (hidden_dim & (hidden_dim - 1)) == 0: + # R @ W_up: apply FWHT to rows of W_up (along axis 0) + W_up_r = _fwht_along_0(W_up) + # W_down @ R = (R @ W_down.T).T: apply FWHT to rows of W_down.T, then transpose + W_down_r = _fwht_along_0(W_down.T).T + sd[f"blocks.{i}.mlp.fc.weight"] = W_up_r.to(sd[f"blocks.{i}.mlp.fc.weight"].dtype) + sd[f"blocks.{i}.mlp.proj.weight"] = W_down_r.to(sd[f"blocks.{i}.mlp.proj.weight"].dtype) + + # --- Attention V→O rotation (per kv-head, per query-head group) --- + W_v = sd[f"blocks.{i}.attn.c_v.weight"].float() # (kv_dim, model_dim) + W_o = sd[f"blocks.{i}.attn.proj.weight"].float() # (model_dim, model_dim) + kv_dim = W_v.shape[0] + if (head_dim & (head_dim - 1)) == 0: + # Reshape V to (num_kv_heads, head_dim, model_dim), rotate head_dim (axis 0 per head) + V = W_v.view(num_kv_heads, head_dim, model_dim) + V_r = torch.stack([_fwht_along_0(V[k]) for k in range(num_kv_heads)], dim=0) + sd[f"blocks.{i}.attn.c_v.weight"] = V_r.view(kv_dim, model_dim).to(W_v.dtype) + + # Reshape O to (model_dim, num_heads, head_dim). + # For each kv-head k: the corresponding query-head range uses the SAME V rotation. + # Apply R to the head_dim columns of O for each query-head in the group. + O = W_o.view(model_dim, num_heads, head_dim) + for k in range(num_kv_heads): + for g in range(kv_group): + h_idx = k * kv_group + g + # O_col block: (model_dim, head_dim). Apply R on head_dim (axis 0 of .T). + # W_o' @ (R @ y_h) = (W_o @ R) @ (R @ y_h) = W_o @ R^2 @ y_h = W_o @ y_h (R^2=I) + # So we need W_o' columns rotated by R: cols = O[:, h_idx, :], rotate last dim. + O[:, h_idx, :] = _fwht_along_0(O[:, h_idx, :].T).T + sd[f"blocks.{i}.attn.proj.weight"] = O.view(model_dim, model_dim).to(W_o.dtype) + + return sd + + +def _compressed_code_size(code): + code_raw = code.encode("utf-8") + minified = subprocess.run( + ["pyminify", "--no-rename-locals", "--no-hoist-literals", "--remove-literal-statements", "-"], + input=code_raw, capture_output=True, check=True, + ).stdout + compressed = lzma.compress(minified) + encoded = base64.b85encode(compressed) + wrapper = b'import lzma as L,base64 as B\nexec(L.decompress(B.b85decode("' + encoded + b'")))\n' + return len(code_raw), len(wrapper) + + +def serialize(h, base_model, code): + code_bytes_uncompressed, code_bytes = _compressed_code_size(code) + if h.is_main_process: + torch.save(base_model.state_dict(), h.model_path) + model_bytes = os.path.getsize(h.model_path) + log(f"Serialized model: {model_bytes} bytes") + log(f"Code size (uncompressed): {code_bytes_uncompressed} bytes") + log(f"Code size (compressed): {code_bytes} bytes") + sd_cpu = _unbank_state_dict(base_model.state_dict(), h.num_layers) + # Layer 2: OptRot — apply Hadamard rotation to (W_up, W_down) and (W_v, W_o) pairs + # BEFORE Hessian collection and GPTQ. The rotation is orthogonal and self-inverse: + # model outputs are preserved exactly, but weight distributions are more uniform, + # reducing int6 quantization error. The Hessian must be collected on the ROTATED model + # so that GPTQ sees the same input statistics as the rotated forward pass. + if getattr(h, "optrot_enabled", False): + log("OptRot: applying pre-GPTQ Hadamard rotation to (W_up,W_down) and (V,O) pairs...") + t_rot = time.perf_counter() + sd_cpu = _optrot_apply(sd_cpu, h) + # Reload rotated weights into base_model so Hessian collection sees rotated activations + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + rotated_banked = _rebank_state_dict(sd_cpu, h.num_layers, h.model_dim, kv_dim, hidden_dim) + base_model.load_state_dict(rotated_banked, strict=True) + log(f"OptRot: done in {time.perf_counter()-t_rot:.1f}s") + device = torch.device("cuda", h.local_rank) + t0 = time.perf_counter() + calib_mode = os.getenv("GPTQ_CALIBRATION_MODE", "random") + if calib_mode == "doc_start": + calib_loader = DocStartSequenceLoader(h, device, bos_id=BOS_ID if BOS_ID is not None else 1) + else: + calib_loader = ShuffledSequenceLoader(h, device) + log("GPTQ:collecting Hessians from calibration data...") + hessians = collect_hessians( + base_model, + calib_loader, + h, + device, + n_calibration_batches=h.gptq_calibration_batches, + ) + log(f"GPTQ:collected {len(hessians)} Hessians in {time.perf_counter()-t0:.1f}s") + quant_result, quant_meta = gptq_mixed_quantize(sd_cpu, hessians, h) + if h.compressor == "pergroup": + quant_blob = _pack_pergroup(quant_result, quant_meta) + else: + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = _compress(quant_raw, h.compressor) + quant_file_bytes = len(quant_blob) + bytes_total = quant_file_bytes + code_bytes + if h.is_main_process: + with open(h.quantized_model_path, "wb") as f: + f.write(quant_blob) + log(f"Serialized model quantized+{h.compressor}: {quant_file_bytes} bytes") + log(f"Total submission size quantized+{h.compressor}: {bytes_total} bytes") + return bytes_total, quant_file_bytes + + +def deserialize(h, device): + eval_model = GPT(h).to(device).bfloat16() + restore_fp32_params(eval_model) + flat_template = _unbank_state_dict(eval_model.state_dict(), h.num_layers) + with open(h.quantized_model_path, "rb") as f: + quant_blob_disk = f.read() + if h.compressor == "pergroup": + quant_state = _unpack_pergroup(quant_blob_disk) + else: + quant_state = torch.load( + io.BytesIO(_decompress(quant_blob_disk, h.compressor)), map_location="cpu" + ) + deq_flat = dequantize_mixed(quant_state["w"], quant_state["m"], flat_template) + head_dim = h.model_dim // h.num_heads + kv_dim = h.num_kv_heads * head_dim + hidden_dim = int(h.mlp_mult * h.model_dim) + deq_state = _rebank_state_dict(deq_flat, h.num_layers, h.model_dim, kv_dim, hidden_dim) + eval_model.load_state_dict(deq_state, strict=True) + return eval_model + + +def _loss_bpb(loss_sum, token_count, byte_count): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + return val_loss, val_bpb + + +def eval_val(h, device, val_data, model, forward_logits_fn=None): + seq_len = h.eval_seq_len + local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + f"VAL_BATCH_SIZE must provide at least one sequence per rank; got VAL_BATCH_SIZE={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + seq_start = total_seqs * h.rank // h.world_size + seq_end = total_seqs * (h.rank + 1) // h.world_size + + # TODO: Don't truncate this. + seq_end = seq_start + ((seq_end - seq_start) // local_batch_seqs) * local_batch_seqs + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + run_forward_logits = ( + (model.module.forward_logits if hasattr(model, "module") else model.forward_logits) + if forward_logits_fn is None + else forward_logits_fn + ) + model.eval() + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + with torch.no_grad(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_data.val_tokens[raw_start:raw_end].to( + device=device, dtype=torch.int64, non_blocking=True + ) + x = local[:-1] + y = local[1:] + bos_pos = (x == BOS_ID).nonzero(as_tuple=True)[0].tolist() + cu_seqlens, max_seqlen = _build_cu_seqlens( + bos_pos, x.numel(), x.device, h.eval_seq_len, 64 + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = run_forward_logits( + x[None], cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ).detach() + per_token_loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y.reshape(-1), + reduction="none", + ) + val_loss_sum += per_token_loss.to(torch.float64).sum() + val_token_count += float(y.numel()) + prev_ids = x + tgt_ids = y + if val_data.caseops_enabled and val_data.val_bytes is not None: + # CaseOps: read per-token byte budget from sidecar at the same + # global positions as the target tokens y. raw_start/raw_end + # span [raw_start, raw_end), x = local[:-1], y = local[1:], + # so y is at sidecar positions [raw_start + 1, raw_end). + sidecar_slice = val_data.val_bytes[raw_start + 1 : raw_end].to( + device=device, dtype=torch.int32, non_blocking=True + ) + val_byte_count += sidecar_slice.to(torch.float64).sum() + else: + token_bytes = val_data.base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += ( + val_data.has_leading_space_lut[tgt_ids] + & ~val_data.is_boundary_token_lut[prev_ids] + ).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + model.train() + return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) + + +def _find_docs(all_tokens): + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = ( + int(bos_positions[i + 1]) + if i + 1 < len(bos_positions) + else all_tokens.numel() + ) + if i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + + +def _build_ttt_global_batches(doc_entries, h, ascending=False): + batch_size = h.ttt_batch_size + global_doc_entries = sorted(doc_entries, key=lambda x: x[1][1]) + global_batches = [ + global_doc_entries[i : i + batch_size] + for i in range(0, len(global_doc_entries), batch_size) + ] + indexed = list(enumerate(global_batches)) + if not ascending: + indexed.sort(key=lambda ib: -max(dl for _, (_, dl) in ib[1])) + return indexed + + +def _init_batch_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(4, "little")) + + +def _claim_next_batch(counter_path, queue_len): + try: + with open(counter_path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + idx = int.from_bytes(f.read(4), "little") + f.seek(0) + f.write((idx + 1).to_bytes(4, "little")) + f.flush() + except FileNotFoundError: + return queue_len + return idx + + +def _compute_chunk_window(ci, pred_len, num_chunks, chunk_size, eval_seq_len): + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_start = ci * chunk_size + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + + +def _accumulate_bpb( + ptl, + x, + y, + chunk_offsets, + chunk_lens, + pos_idx, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, + y_bytes=None, +): + pos = pos_idx[: x.size(1)].unsqueeze(0) + mask = ( + (chunk_lens.unsqueeze(1) > 0) + & (pos >= chunk_offsets.unsqueeze(1)) + & (pos < (chunk_offsets + chunk_lens).unsqueeze(1)) + ) + mask_f64 = mask.to(torch.float64) + if y_bytes is not None: + tok_bytes = y_bytes.to(torch.float64) + else: + tok_bytes = base_bytes_lut[y].to(torch.float64) + tok_bytes += (has_leading_space_lut[y] & ~is_boundary_token_lut[x]).to( + torch.float64 + ) + loss_sum += (ptl.to(torch.float64) * mask_f64).sum() + byte_sum += (tok_bytes * mask_f64).sum() + token_count += chunk_lens.to(torch.float64).sum() + + +def _loss_bpb_from_sums(loss_sum, token_count, byte_sum): + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_sum.item()) + return val_loss, val_bpb + + +def _add_to_counter(path, delta): + try: + with open(path, "r+b") as f: + fcntl.flock(f, fcntl.LOCK_EX) + cur = int.from_bytes(f.read(8), "little", signed=True) + cur += int(delta) + f.seek(0) + f.write(int(cur).to_bytes(8, "little", signed=True)) + f.flush() + return cur + except FileNotFoundError: + return int(delta) + + +def _init_int64_counter(path): + with open(path, "wb") as f: + f.write((0).to_bytes(8, "little", signed=True)) + + +def _select_ttt_doc_entries(docs, h): + doc_entries = list(enumerate(docs)) + if h.val_doc_fraction < 1.0: + sample_n = max(1, int(round(len(docs) * h.val_doc_fraction))) + sampled_indices = sorted( + random.Random(h.seed).sample(range(len(docs)), sample_n) + ) + return [(i, docs[i]) for i in sampled_indices] + return doc_entries + + +def train_val_ttt_global_sgd_distributed(h, device, val_data, base_model, val_tokens, batch_seqs=None): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + base_model.eval() + seq_len = h.eval_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = h.global_ttt_chunk_tokens + batch_seqs = h.global_ttt_batch_seqs if batch_seqs is None else batch_seqs + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + ttt_params = [p for p in base_model.parameters()] + for p in ttt_params: + p.requires_grad_(True) + # Layer 4: LaCT — use Muon as the fast-weight optimizer for global TTT. + # GLOBAL_TTT_OPTIMIZER=muon: Newton-Schulz orthogonalized updates for matrix params, + # SGD for scalars/vectors. Matches the training optimizer, enabling gradient norms + # compatible with learned LR scale — the key LaCT efficiency improvement. + _global_ttt_opt = getattr(h, "global_ttt_optimizer", "sgd") + if _global_ttt_opt == "muon": + _ttt_matrix_params = [p for p in ttt_params if p.ndim >= 2] + _ttt_scalar_params = [p for p in ttt_params if p.ndim < 2] + _ns_steps = int(getattr(h, "global_ttt_muon_ns_steps", 5)) + _nesterov = bool(getattr(h, "global_ttt_muon_nesterov", True)) + optimizer = Muon( + _ttt_matrix_params, + lr=h.global_ttt_lr, + momentum=h.global_ttt_momentum, + backend_steps=_ns_steps, + nesterov=_nesterov, + weight_decay=0.0, + ) + _scalar_opt = ( + torch.optim.SGD(_ttt_scalar_params, lr=h.global_ttt_lr, momentum=h.global_ttt_momentum) + if _ttt_scalar_params else None + ) + def _ttt_zero_grad(): + optimizer.zero_grad(set_to_none=True) + if _scalar_opt: + _scalar_opt.zero_grad(set_to_none=True) + def _ttt_step(): + optimizer.step() + if _scalar_opt: + _scalar_opt.step() + else: + optimizer = torch.optim.SGD( + ttt_params, lr=h.global_ttt_lr, momentum=h.global_ttt_momentum + ) + _scalar_opt = None + def _ttt_zero_grad(): + optimizer.zero_grad(set_to_none=True) + def _ttt_step(): + optimizer.step() + t_start = time.perf_counter() + for ci in range(num_chunks): + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + is_last_chunk = ci == num_chunks - 1 + if is_last_chunk or h.global_ttt_epochs <= 0: + continue + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs <= 0: + continue + warmup_chunks = max(0, min(h.global_ttt_warmup_chunks, num_chunks - 1)) + if warmup_chunks > 0 and ci < warmup_chunks: + warmup_denom = max(warmup_chunks - 1, 1) + warmup_t = ci / warmup_denom + lr_now = ( + h.global_ttt_warmup_start_lr + + (h.global_ttt_lr - h.global_ttt_warmup_start_lr) * warmup_t + ) + else: + decay_steps = max(num_chunks - 1 - warmup_chunks, 1) + decay_ci = max(ci - warmup_chunks, 0) + lr_now = h.global_ttt_lr * 0.5 * ( + 1.0 + math.cos(math.pi * decay_ci / decay_steps) + ) + for pg in optimizer.param_groups: + pg["lr"] = lr_now + if _scalar_opt is not None: + for pg in _scalar_opt.param_groups: + pg["lr"] = lr_now + my_seq_s = chunk_seqs * h.rank // h.world_size + my_seq_e = chunk_seqs * (h.rank + 1) // h.world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ in range(h.global_ttt_epochs): + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x_flat = local[:-1] + y_flat = local[1:] + _ttt_zero_grad() + with torch.enable_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if h.global_ttt_respect_doc_boundaries: + bos_pos = (x_flat == BOS_ID).nonzero(as_tuple=True)[0].tolist() + cu_seqlens, max_seqlen = _build_cu_seqlens( + bos_pos, x_flat.numel(), x_flat.device, h.eval_seq_len, 64 + ) + loss = base_model( + x_flat[None], + y_flat[None], + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + else: + x = x_flat.reshape(-1, seq_len) + y = y_flat.reshape(-1, seq_len) + loss = base_model(x, y) + loss.backward() + if dist.is_available() and dist.is_initialized(): + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.SUM) + p.grad.mul_(1.0 / h.world_size) + if h.global_ttt_grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ttt_params, h.global_ttt_grad_clip) + _ttt_step() + base_model.eval() + if h.rank == 0: + elapsed = time.perf_counter() - t_start + log( + f"tttg: c{ci+1}/{num_chunks} lr:{lr_now:.6f} t:{elapsed:.1f}s" + ) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + +def _compute_ngram_hints_for_val(h, val_data, log0=print): + """Precompute n-gram hints before eval timer starts (Stage 1A from PR #1967). + + Returns (hint_global, gate_global, boost_global) CPU tensors, or None if tilt disabled. + Single L->R causal pass over val tokens only — compliant with C1/C3/C4 constraints. + """ + if not getattr(h, "ngram_tilt_enabled", False): + return None + from online_ngram_tilt import build_hints_for_targets + all_tokens = val_data.val_tokens + targets_np_all = all_tokens.cpu().numpy().astype("uint16", copy=False)[1:] + t_h0 = time.perf_counter() + hints_pkg = build_hints_for_targets( + target_token_ids_np=targets_np_all, + tokenizer_path=h.tokenizer_path, + vocab_size=h.vocab_size, + log0=log0, + token_order=h.token_order, + token_threshold=h.token_threshold, + token_boost=h.token_boost, + within_tau=h.within_tau, + within_boost=h.within_boost, + word_order=h.word_order, + word_normalize=h.word_normalize, + word_tau=h.word_tau, + word_boost=h.word_boost, + agree_add_boost=h.agree_add_boost, + ) + hint_global = torch.from_numpy(hints_pkg["hint_ids"].astype("int64")) + gate_global = torch.from_numpy(hints_pkg["gate_mask"]) + boost_global = torch.from_numpy(hints_pkg["boost"].astype("float32")) + log0( + f"ngram_tilt:precompute_outside_timer_done elapsed={time.perf_counter()-t_h0:.2f}s " + f"total_targets={hint_global.numel()}" + ) + return (hint_global, gate_global, boost_global) + + +def eval_val_ttt_phased(h, base_model, device, val_data, forward_ttt_train, precomputed_hints=None): + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + TTT_LORA_EMA_DECAY = float(os.environ.get("TTT_LORA_EMA_DECAY", "0.0")) + ttt_lora_ema_enabled = TTT_LORA_EMA_DECAY > 0.0 + TTT_UPDATE_EVERY = int(os.environ.get("TTT_UPDATE_EVERY", "1")) + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + all_tokens = val_data.val_tokens + all_tokens_idx = all_tokens.to(torch.int32) + # === PR #1145 n-gram tilt: set up hint tensors (CPU) === + # hint_global[i] = hinted token id for predicting all_tokens[i+1] from prefix [:i+1]. + # gate_global[i] = True when any expert fires for position i. + # boost_global[i] = combined boost beta for position i. + ngram_hint_global = None + ngram_gate_global = None + ngram_boost_global = None + if precomputed_hints is not None: + ngram_hint_global, ngram_gate_global, ngram_boost_global = precomputed_hints + log( + f"ngram_tilt:using_precomputed_hints " + f"total_targets={ngram_hint_global.numel()} (precompute excluded from eval timer)" + ) + elif getattr(h, "ngram_tilt_enabled", False): + from online_ngram_tilt import build_hints_for_targets + targets_np_all = all_tokens.cpu().numpy().astype("uint16", copy=False)[1:] + t_h0 = time.perf_counter() + hints_pkg = build_hints_for_targets( + target_token_ids_np=targets_np_all, + tokenizer_path=h.tokenizer_path, + vocab_size=h.vocab_size, + log0=log, + token_order=h.token_order, + token_threshold=h.token_threshold, + token_boost=h.token_boost, + within_tau=h.within_tau, + within_boost=h.within_boost, + word_order=h.word_order, + word_normalize=h.word_normalize, + word_tau=h.word_tau, + word_boost=h.word_boost, + agree_add_boost=h.agree_add_boost, + ) + ngram_hint_global = torch.from_numpy(hints_pkg["hint_ids"].astype("int64")) + ngram_gate_global = torch.from_numpy(hints_pkg["gate_mask"]) + ngram_boost_global = torch.from_numpy(hints_pkg["boost"].astype("float32")) + log( + f"ngram_tilt:precompute_done elapsed={time.perf_counter()-t_h0:.2f}s " + f"total_targets={ngram_hint_global.numel()}" + ) + docs = _find_docs(all_tokens) + doc_entries = _select_ttt_doc_entries(docs, h) + prefix_doc_limit = max(0, min(len(doc_entries), int(h.phased_ttt_prefix_docs))) + num_phases = max(1, int(h.phased_ttt_num_phases)) + phase_boundaries = [] + for pi in range(num_phases): + boundary = prefix_doc_limit * (pi + 1) // num_phases + phase_boundaries.append(boundary) + current_phase = 0 + current_phase_boundary = phase_boundaries[0] + log( + "ttt_phased:" + f" total_docs:{len(doc_entries)} prefix_docs:{prefix_doc_limit} " + f"suffix_docs:{len(doc_entries) - prefix_doc_limit}" + f" num_phases:{num_phases} boundaries:{phase_boundaries}" + ) + chunk_size, eval_seq_len = h.ttt_chunk_size, h.ttt_eval_seq_len + eval_batch_set = None + if h.ttt_eval_batches: + eval_batch_set = set(int(x) for x in h.ttt_eval_batches.split(",") if x.strip()) + use_ascending = eval_batch_set is not None + global_batches_sorted = _build_ttt_global_batches( + doc_entries, h, ascending=use_ascending + ) + queue_len = len(global_batches_sorted) + counter_path = f"/tmp/ttt_counter_{h.run_id}" + prefix_counter_path = f"/tmp/ttt_prefix_counter_{h.run_id}" + pause_flag_path = f"/tmp/ttt_pause_flag_{h.run_id}" + if h.rank == 0: + _init_batch_counter(counter_path) + _init_int64_counter(prefix_counter_path) + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + path_list = [counter_path, prefix_counter_path, pause_flag_path] + dist.broadcast_object_list(path_list, src=0) + counter_path, prefix_counter_path, pause_flag_path = path_list + dist.barrier() + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + t_start = time.perf_counter() + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + reusable_ema_lora = ( + BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + if ttt_lora_ema_enabled else None + ) + + def _build_opt(lora): + if h.ttt_optimizer == "sgd": + return torch.optim.SGD( + lora.parameters(), lr=h.ttt_lora_lr, + momentum=h.ttt_beta1, weight_decay=h.ttt_weight_decay, + ) + return torch.optim.AdamW( + lora.parameters(), lr=h.ttt_lora_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, weight_decay=h.ttt_weight_decay, fused=True, + ) + + reusable_opt = _build_opt(reusable_lora) + local_scored_docs = [] + global_ttt_done = prefix_doc_limit == 0 + try: + while True: + queue_idx = _claim_next_batch(counter_path, queue_len) + if queue_idx >= queue_len: + break + orig_batch_idx, batch_entries = global_batches_sorted[queue_idx] + batch = [doc for _, doc in batch_entries] + bsz = len(batch) + prev_loss = loss_sum.item() + prev_bytes = byte_sum.item() + prev_tokens = token_count.item() + if bsz == reusable_lora.bsz: + reusable_lora.reset() + for s in reusable_opt.state.values(): + for k, v in s.items(): + if isinstance(v, torch.Tensor): + v.zero_() + elif k == "step": + s[k] = 0 + cur_lora = reusable_lora + cur_opt = reusable_opt + if ttt_lora_ema_enabled: + reusable_ema_lora.reset() + with torch.no_grad(): + for ema_p, raw_p in zip(reusable_ema_lora.parameters(), cur_lora.parameters()): + ema_p.data.copy_(raw_p.data) + cur_ema_lora = reusable_ema_lora + else: + cur_lora = BatchedTTTLoRA( + bsz, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + cur_opt = _build_opt(cur_lora) + if ttt_lora_ema_enabled: + cur_ema_lora = BatchedTTTLoRA( + bsz, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + with torch.no_grad(): + for ema_p, raw_p in zip(cur_ema_lora.parameters(), cur_lora.parameters()): + ema_p.data.copy_(raw_p.data) + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + num_chunks_t = torch.tensor(num_chunks, dtype=torch.int64, device=device) + for ci in range(max_nc): + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + tok_starts = torch.zeros(bsz, dtype=torch.int64) + tok_wls = torch.zeros(bsz, dtype=torch.int64) + chunk_offsets_cpu = torch.zeros(bsz, dtype=torch.int64) + chunk_lens_cpu = torch.zeros(bsz, dtype=torch.int64) + for b in range(bsz): + if not active[b]: + continue + doc_start, doc_len = batch[b] + win_start, win_len, chunk_offset, chunk_len = _compute_chunk_window( + ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len + ) + tok_starts[b] = doc_start + win_start + tok_wls[b] = win_len + chunk_offsets_cpu[b] = chunk_offset + chunk_lens_cpu[b] = chunk_len + _, context_size, chunk_offset, _ = _compute_chunk_window( + ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len + ) + col_idx = torch.arange(context_size + 1) + idx = tok_starts.unsqueeze(1) + col_idx.unsqueeze(0) + idx.clamp_(max=all_tokens.numel() - 1) + gathered_gpu = all_tokens_idx[idx].to( + device=device, dtype=torch.int64, non_blocking=True + ) + valid = (col_idx[:context_size].unsqueeze(0) < tok_wls.unsqueeze(1)).to( + device, non_blocking=True + ) + chunk_offsets = chunk_offsets_cpu.to(device, non_blocking=True) + chunk_lens = chunk_lens_cpu.to(device, non_blocking=True) + x = torch.where(valid, gathered_gpu[:, :context_size], 0) + y = torch.where(valid, gathered_gpu[:, 1 : context_size + 1], 0) + ctx_pos = torch.arange(context_size, device=device, dtype=torch.int64) + # n-gram tilt: gather hints aligned to y for this chunk + hint_ids_gpu = None + gate_mask_gpu = None + boost_gpu = None + if ngram_hint_global is not None: + hint_idx_cpu = ( + tok_starts.unsqueeze(1) + col_idx[:context_size].unsqueeze(0) + ).clamp_(min=0, max=ngram_hint_global.numel() - 1) + hint_ids_gpu = ngram_hint_global[hint_idx_cpu].to( + device=device, dtype=torch.int64, non_blocking=True + ) + gate_mask_gpu = ngram_gate_global[hint_idx_cpu].to( + device=device, non_blocking=True + ) + boost_gpu = ngram_boost_global[hint_idx_cpu].to( + device=device, dtype=torch.float32, non_blocking=True + ) + hint_ids_gpu = torch.where(valid, hint_ids_gpu, torch.zeros_like(hint_ids_gpu)) + gate_mask_gpu = gate_mask_gpu & valid + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if hint_ids_gpu is not None: + per_tok_loss, log_q_hint = forward_ttt_train( + x, y, lora=cur_ema_lora if ttt_lora_ema_enabled else cur_lora, + hint_ids=hint_ids_gpu, + ) + else: + per_tok_loss = forward_ttt_train( + x, y, lora=cur_ema_lora if ttt_lora_ema_enabled else cur_lora + ) + log_q_hint = None + # Apply closed-form tilt to BPB accumulation only (not to TTT training objective). + if hint_ids_gpu is not None and log_q_hint is not None: + from online_ngram_tilt import apply_tilt_to_ptl_torch_fast + tilted_loss = apply_tilt_to_ptl_torch_fast( + ptl=per_tok_loss, + log_q_hint=log_q_hint, + target_ids=y, + hint_ids=hint_ids_gpu, + gate_mask=gate_mask_gpu, + boost=boost_gpu, + ) + else: + tilted_loss = per_tok_loss + # CaseOps sidecar-driven byte budget. Mirror the index pattern + # used to build y from all_tokens: y[b, j] corresponds to the + # token at global position tok_starts[b] + 1 + j (when valid). + y_bytes_arg = None + if val_data.caseops_enabled and val_data.val_bytes is not None: + y_idx = ( + tok_starts.unsqueeze(1) + + 1 + + col_idx[:context_size].unsqueeze(0) + ) + y_idx = y_idx.clamp_(max=val_data.val_bytes.numel() - 1) + y_bytes_arg = val_data.val_bytes[y_idx].to( + device=device, dtype=torch.int32, non_blocking=True + ) + # Mirror the `valid` masking used for y so out-of-range tokens + # contribute zero bytes (matches y=0 substitution above). + y_bytes_arg = torch.where( + valid, y_bytes_arg, torch.zeros_like(y_bytes_arg) + ) + with torch.no_grad(): + _accumulate_bpb( + tilted_loss, + x, + y, + chunk_offsets, + chunk_lens, + ctx_pos, + val_data.base_bytes_lut, + val_data.has_leading_space_lut, + val_data.is_boundary_token_lut, + loss_sum, + byte_sum, + token_count, + y_bytes=y_bytes_arg, + ) + if needs_train: + activate_chunk_mask = (num_chunks_t - 1 > ci).float() + is_last_trained_chunk = (ci == max_nc - 2) + is_update_step = (ci % TTT_UPDATE_EVERY == TTT_UPDATE_EVERY - 1) or is_last_trained_chunk + is_window_start = (ci % TTT_UPDATE_EVERY == 0) + for gi in range(h.ttt_grad_steps): + if gi > 0 or ttt_lora_ema_enabled: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + per_tok_loss = forward_ttt_train(x, y, lora=cur_lora) + per_doc = per_tok_loss[ + :, chunk_offset : chunk_offset + chunk_size + ].mean(dim=-1) + # Original path: zero_grad only at the start of an update window + # (gi==0). With TTT_GRAD_STEPS>1 this leaves gi=0's gradient in + # .grad when gi=1 runs, so step 2 sees (grad_gi0 + grad_gi1) — + # effectively ~2x gradient magnitude on the second update. + # Clean path (TTT_GRAD_STEPS_CLEAN=1): also zero_grad before every + # gi>0 step so each optimizer.step() sees only its own fresh gradient, + # giving true independent half-LR updates with different curvature. + if (is_window_start and gi == 0) or (h.ttt_grad_steps_clean and gi > 0): + cur_opt.zero_grad(set_to_none=True) + (per_doc * activate_chunk_mask).sum().backward() + if is_update_step: + cur_opt.step() + if ttt_lora_ema_enabled and is_update_step: + with torch.no_grad(): + decay = TTT_LORA_EMA_DECAY + for ema_p, raw_p in zip(cur_ema_lora.parameters(), cur_lora.parameters()): + ema_p.data.mul_(decay).add_(raw_p.data, alpha=1.0 - decay) + else: + del per_tok_loss + batch_num = orig_batch_idx + 1 + doc_lens = [dl for _, dl in batch] + should_report = batch_num in eval_batch_set if eval_batch_set is not None else True + if should_report: + cur_tokens = token_count.item() + cur_loss_val = loss_sum.item() + cur_bytes_val = byte_sum.item() + dt = cur_tokens - prev_tokens + db = cur_bytes_val - prev_bytes + if dt > 0 and db > 0: + b_loss = (cur_loss_val - prev_loss) / dt + b_bpb = b_loss / math.log(2.0) * (dt / db) + else: + b_loss = b_bpb = 0.0 + r_loss = cur_loss_val / max(cur_tokens, 1) + r_bpb = r_loss / math.log(2.0) * (cur_tokens / max(cur_bytes_val, 1)) + elapsed = time.perf_counter() - t_start + log( + f"ttp: b{batch_num}/{queue_len} bl:{b_loss:.4f} bb:{b_bpb:.4f} " + f"rl:{r_loss:.4f} rb:{r_bpb:.4f} dl:{min(doc_lens)}-{max(doc_lens)} " + f"gd:{int(global_ttt_done)}" + ) + if not global_ttt_done: + local_scored_docs.extend( + (orig_batch_idx, pos, doc_start, doc_len) + for pos, (doc_start, doc_len) in enumerate(batch) + ) + prefix_done = _add_to_counter(prefix_counter_path, len(batch_entries)) + if prefix_done >= current_phase_boundary: + try: + with open(pause_flag_path, "x"): + pass + except FileExistsError: + pass + should_pause = os.path.exists(pause_flag_path) + if should_pause: + if dist.is_available() and dist.is_initialized(): + dist.barrier() + gathered_scored_docs = [None] * h.world_size + if dist.is_available() and dist.is_initialized(): + dist.all_gather_object(gathered_scored_docs, local_scored_docs) + else: + gathered_scored_docs = [local_scored_docs] + scored_docs_for_global = [] + for rank_docs in gathered_scored_docs: + if rank_docs: + scored_docs_for_global.extend(rank_docs) + scored_docs_for_global.sort(key=lambda x: (x[0], x[1])) + scored_docs_for_global = scored_docs_for_global[:current_phase_boundary] + scored_token_chunks = [ + val_data.val_tokens[doc_start : doc_start + doc_len] + for _, _, doc_start, doc_len in scored_docs_for_global + ] + if scored_token_chunks: + global_ttt_tokens = torch.cat(scored_token_chunks) + else: + global_ttt_tokens = val_data.val_tokens[:0] + if h.rank == 0: + prefix_done = 0 + try: + with open(prefix_counter_path, "rb") as f: + prefix_done = int.from_bytes( + f.read(8), "little", signed=True + ) + except FileNotFoundError: + pass + log( + f"ttpp: phase:{current_phase + 1}/{num_phases} pd:{prefix_done} " + f"gd:{len(scored_docs_for_global)} " + f"t:{time.perf_counter() - t_start:.1f}s" + ) + train_val_ttt_global_sgd_distributed( + h, device, val_data, base_model, global_ttt_tokens + ) + for p in base_model.parameters(): + p.requires_grad_(False) + reusable_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + reusable_opt = _build_opt(reusable_lora) + if ttt_lora_ema_enabled: + reusable_ema_lora = BatchedTTTLoRA( + h.ttt_batch_size, base_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + current_phase += 1 + if current_phase >= num_phases: + global_ttt_done = True + else: + current_phase_boundary = phase_boundaries[current_phase] + if h.rank == 0: + try: + os.remove(pause_flag_path) + except FileNotFoundError: + pass + if dist.is_available() and dist.is_initialized(): + dist.barrier() + if h.rank == 0: + log(f"ttpr: phase:{current_phase}/{num_phases} t:{time.perf_counter() - t_start:.1f}s") + del cur_lora, cur_opt + if ttt_lora_ema_enabled: + del cur_ema_lora + finally: + pass + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.train() + return _loss_bpb_from_sums(loss_sum, token_count, byte_sum) + + +def timed_eval(label, fn, *args, **kwargs): + torch.cuda.synchronize() + t0 = time.perf_counter() + val_loss, val_bpb = fn(*args, **kwargs) + torch.cuda.synchronize() + elapsed_ms = 1e3 * (time.perf_counter() - t0) + log( + f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms" + ) + return val_loss, val_bpb + + +def train_model(h, device, val_data): + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + base_model.forward_logits, dynamic=False, fullgraph=True + ) + model = compiled_model + log(f"model_params:{sum(p.numel()for p in base_model.parameters())}") + optimizers = Optimizers(h, base_model) + train_loader = DocumentPackingLoader(h, device) + max_wallclock_ms = ( + 1e3 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + ) + if max_wallclock_ms is not None: + max_wallclock_ms -= h.gptq_reserve_seconds * 1e3 + log( + f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms" + ) + + def training_frac(step, elapsed_ms): + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-09) + + def lr_mul(frac): + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + def step_fn(step, lr_scale): + optimizers.zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + x, y, cu_seqlens, _max_seqlen = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y, cu_seqlens=cu_seqlens, max_seqlen=h.train_seq_len) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + frac = ( + min(step / h.muon_momentum_warmup_steps, 1.0) + if h.muon_momentum_warmup_steps > 0 + else 1.0 + ) + muon_momentum = ( + 1 - frac + ) * h.muon_momentum_warmup_start + frac * h.muon_momentum + for group in optimizers.optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * lr_scale + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm) + optimizers.step(distributed=h.distributed) + return train_loss + + if h.warmup_steps > 0: + initial_model_state = { + name: tensor.detach().cpu().clone() + for (name, tensor) in base_model.state_dict().items() + } + initial_optimizer_states = [ + copy.deepcopy(opt.state_dict()) for opt in optimizers + ] + model.train() + num_tokens_local = h.train_batch_tokens // h.world_size + for blk in base_model.blocks: + blk.attn.rotary(num_tokens_local, device, torch.bfloat16) + cu_bucket_size = train_loader.cu_bucket_size + warmup_cu_buckets = tuple(cu_bucket_size * i for i in range(1, 5)) + warmup_cu_iters = 3 + x, y, cu_seqlens, _ = train_loader.next_batch( + h.train_batch_tokens, h.grad_accum_steps + ) + log(f"warmup_cu_buckets:{','.join(str(b) for b in warmup_cu_buckets)} iters_each:{warmup_cu_iters}") + def _run_cu_bucket_warmup(): + for bucket_len in warmup_cu_buckets: + boundaries = list(range(0, x.size(1), max(h.train_seq_len, 1))) + if boundaries[-1] != x.size(1): + boundaries.append(x.size(1)) + cu = torch.full((bucket_len,), x.size(1), dtype=torch.int32, device=device) + cu[: len(boundaries)] = torch.tensor(boundaries, dtype=torch.int32, device=device) + for _ in range(warmup_cu_iters): + optimizers.zero_grad_all() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + wloss = model(x, y, cu_seqlens=cu, max_seqlen=h.train_seq_len) + (wloss / h.grad_accum_steps).backward() + optimizers.zero_grad_all() + _run_cu_bucket_warmup() + if h.num_loops > 0: + base_model.looping_active = True + _run_cu_bucket_warmup() + base_model.looping_active = False + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"warmup_step: {warmup_step+1}/{h.warmup_steps}") + if h.num_loops > 0: + base_model.looping_active = True + log( + f"loop_warmup:enabled encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if ( + warmup_step <= 5 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == h.warmup_steps + ): + log(f"loop_warmup_step: {warmup_step+1}/{h.warmup_steps}") + base_model.looping_active = False + base_model.load_state_dict(initial_model_state, strict=True) + for (opt, state) in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + optimizers.zero_grad_all() + train_loader = DocumentPackingLoader(h, device) + ema_state = { + name: t.detach().float().clone() + for (name, t) in base_model.state_dict().items() + } + ema_decay = h.ema_decay + training_time_ms = 0.0 + stop_after_step = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = ( + step == h.iterations + or stop_after_step is not None + and step >= stop_after_step + ) + should_validate = ( + last_step or h.val_loss_every > 0 and step % h.val_loss_every == 0 + ) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1e3 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + h, device, val_data, model, compiled_forward_logits + ) + log( + f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < h.iterations: + log( + f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms step: {step}/{h.iterations}" + ) + break + elapsed_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + scale = lr_mul(frac) + if ( + h.num_loops > 0 + and not base_model.looping_active + and frac >= h.enable_looping_at + ): + base_model.looping_active = True + log( + f"layer_loop:enabled step:{step} frac:{frac:.3f} encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}" + ) + train_loss = step_fn(step, scale) + with torch.no_grad(): + for (name, t) in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_( + t.detach().float(), alpha=1.0 - ema_decay + ) + step += 1 + approx_training_time_ms = training_time_ms + 1e3 * (time.perf_counter() - t0) + should_log_train = h.train_log_every > 0 and ( + step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None + ) + if should_log_train: + tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1e3) + log( + f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} train_time: {approx_training_time_ms/60000:.1f}m tok/s: {tok_per_sec:.0f}" + ) + reached_cap = ( + max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + ) + if h.distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log( + f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB" + ) + log("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = { + name: t.to(dtype=current_state[name].dtype) for (name, t) in ema_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + return base_model, compiled_model, compiled_forward_logits + + +def train_and_eval(h, device): + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + if h.artifact_dir and h.is_main_process: + os.makedirs(h.artifact_dir, exist_ok=True) + val_data = ValidationData(h, device) + log( + f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}" + ) + log(f"val_tokens: {val_data.val_tokens.numel()-1}") + # TTT_EVAL_ONLY: skip training + GPTQ, jump straight to TTT eval on a + # pre-existing quantized artifact. Used to test TTT-only improvements + # (e.g., PR-1767's alpha/warm-start/WD) without retraining. + ttt_eval_only = os.environ.get("TTT_EVAL_ONLY", "0") == "1" + if ttt_eval_only: + log("TTT_EVAL_ONLY=1 — skipping training + GPTQ, loading saved artifact for TTT eval") + log(f"ttt_lora_alpha: {BatchedLinearLoRA._ALPHA}") + log(f"ttt_warm_start_a: {BatchedLinearLoRA._WARM_START_A}") + log(f"ttt_weight_decay: {h.ttt_weight_decay}") + else: + base_model, compiled_model, compiled_forward_logits = train_model( + h, device, val_data + ) + torch._dynamo.reset() + timed_eval( + "diagnostic pre-quantization post-ema", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + if os.environ.get("PREQUANT_ONLY", "0") == "1": + log("PREQUANT_ONLY=1 — skipping serialize/GPTQ/post-quant eval/TTT") + return + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + eval_model = deserialize(h, device) + if h.num_loops > 0: + eval_model.looping_active = True + if not ttt_eval_only: + compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True) + compiled_forward_logits = torch.compile( + eval_model.forward_logits, dynamic=False, fullgraph=True + ) + timed_eval( + "diagnostic quantized", + eval_val, + h, + device, + val_data, + compiled_model, + compiled_forward_logits, + ) + del eval_model + if h.ttt_enabled: + if not ttt_eval_only: + del compiled_model + if ttt_eval_only: + del eval_model + torch._dynamo.reset() + torch.cuda.empty_cache() + ttt_model = deserialize(h, device) + if h.num_loops > 0: + ttt_model.looping_active = True + for p in ttt_model.parameters(): + p.requires_grad_(False) + + if h.rope_yarn: + _yarn_seqlen = h.train_batch_tokens // h.grad_accum_steps + for block in ttt_model.blocks: + block.attn.rotary(_yarn_seqlen, device, torch.bfloat16) + else: + for block in ttt_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + block.attn.rotary._seq_len_cached = 0 + block.attn.rotary(h.ttt_eval_seq_len, device, torch.bfloat16) + + def _fwd_ttt_inner(input_ids, target_ids, lora): + return ttt_model.forward_ttt(input_ids, target_ids, lora=lora) + + def _fwd_ttt_inner_with_hints(input_ids, target_ids, lora, hint_ids): + return ttt_model.forward_ttt(input_ids, target_ids, lora=lora, hint_ids=hint_ids) + + _fwd_ttt_compiled_inner = None + _fwd_ttt_compiled_inner_hints = None + + def _fwd_ttt(input_ids, target_ids, lora, hint_ids=None): + nonlocal _fwd_ttt_compiled_inner, _fwd_ttt_compiled_inner_hints + if hint_ids is None: + if _fwd_ttt_compiled_inner is None: + _fwd_ttt_compiled_inner = torch.compile(_fwd_ttt_inner, dynamic=True) + return _fwd_ttt_compiled_inner(input_ids, target_ids, lora=lora) + if _fwd_ttt_compiled_inner_hints is None: + _fwd_ttt_compiled_inner_hints = torch.compile( + _fwd_ttt_inner_with_hints, dynamic=True + ) + return _fwd_ttt_compiled_inner_hints( + input_ids, target_ids, lora=lora, hint_ids=hint_ids + ) + + fwd_ttt_compiled = _fwd_ttt + log(f"ttt_grad_steps: {h.ttt_grad_steps} ttt_grad_steps_clean: {h.ttt_grad_steps_clean}") + log(f"ttt_lora:warming up compile (random tokens, no val data)") + global BOS_ID + if BOS_ID is None: + BOS_ID = 1 + t_warmup = time.perf_counter() + warmup_bszes = [h.ttt_batch_size] + for bsz in warmup_bszes: + wl = BatchedTTTLoRA( + bsz, ttt_model, h.ttt_lora_rank, + k_lora=h.ttt_k_lora, mlp_lora=h.ttt_mlp_lora, o_lora=h.ttt_o_lora, + ).to(device) + wo = torch.optim.AdamW( + wl.parameters(), + lr=h.ttt_lora_lr, + betas=(h.ttt_beta1, h.ttt_beta2), + eps=1e-10, + weight_decay=h.ttt_weight_decay, + fused=True, + ) + for ctx_len in (h.ttt_chunk_size, h.ttt_eval_seq_len): + xw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + yw = torch.randint(0, h.vocab_size, (bsz, ctx_len), device=device, dtype=torch.int64) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = fwd_ttt_compiled(xw, yw, lora=wl) + ptl[:, : min(h.ttt_chunk_size, ctx_len)].mean(dim=-1).sum().backward() + wo.step() + wo.zero_grad(set_to_none=True) + del wl, wo + torch.cuda.empty_cache() + compile_elapsed = time.perf_counter() - t_warmup + log(f"ttt_lora:compile warmup done ({compile_elapsed:.1f}s)") + # v5 Stage 1A: precompute n-gram hints BEFORE eval timer (single causal pass, + # val tokens only — same compliance as inline). Saves ~168s of measured eval + # time for full tilt without any loss of tilt benefit. + precomputed_hints = None + if h.ngram_tilt_enabled and h.ngram_hint_precompute_outside: + log("ngram_tilt:precomputing hints OUTSIDE eval timer") + precomputed_hints = _compute_ngram_hints_for_val(h, val_data, log0=log) + log("\nbeginning TTT eval timer") + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_phased( + h, ttt_model, device, val_data, forward_ttt_train=fwd_ttt_compiled, + precomputed_hints=precomputed_hints, + ) + torch.cuda.synchronize() + ttt_eval_elapsed = time.perf_counter() - t_ttt + log( + "quantized_ttt_phased " + f"val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f} " + f"eval_time:{1e3*ttt_eval_elapsed:.0f}ms" + ) + log(f"total_eval_time:{ttt_eval_elapsed:.1f}s") + del ttt_model + + +def main(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError( + f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral" + ) + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + from torch.backends.cuda import ( + enable_cudnn_sdp, + enable_flash_sdp, + enable_math_sdp, + enable_mem_efficient_sdp, + ) + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.cache_size_limit = 16 + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + os.makedirs(h.artifact_dir if h.artifact_dir else "logs", exist_ok=True) + log(100 * "=", console=False) + log("Hyperparameters:", console=True) + for (k, v) in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}", console=True) + log("=" * 100, console=False) + log("Source code:", console=False) + log("=" * 100, console=False) + with open(__file__, "r", encoding="utf-8") as _src: + log(_src.read(), console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running PyTorch {torch.__version__}", console=False) + log("=" * 100, console=False) + train_and_eval(h, device) + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main()