diff --git a/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/README.md b/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/README.md new file mode 100644 index 0000000000..60a1dabd3e --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/README.md @@ -0,0 +1,84 @@ +First SSM-based entry in either track. Trained on 4×H200 SXM for one hour rather than 8×H100 for ten minutes — non-record on both hardware and time. Eval is the standard root-harness full-window val + int8/brotli quant; no TTT, no sliding-window, no GPTQ. + +**val_bpb 1.30040 / 12.08 MB / seed=1337** + +## Architecture + +**Topology.** 7 distinct transformer-style blocks, each shared across 3 sequential applications via depth recurrence (`NUM_UNIQUE_LAYERS=7 NUM_LOOPS=3`). Effective compute depth is 21; total stored body parameters are equivalent to 7 layers. No U-Net skip; the loop is plain weight reuse. `MODEL_DIM=512`, tied input/output embeddings, sp1024 vocab. + +**Block contents.** Each block is `(attn || SSM) + MLP`, not the standard `attn → MLP` chain: + +1. RMSNorm → run attention and the SSM in parallel on the same normalized input → sum their outputs (with independent per-channel learned scales `attn_scale`, `s4d_scale`) into the residual stream. +2. RMSNorm → SwiGLU MLP (`MLP_MULT=8`, hidden = 8·dim) → add to the residual stream with its own per-channel scale (`mlp_scale`). +3. A learned 2-vector `resid_mix` per block interpolates the incoming residual between the live stream `x` and the original post-embedding `x0` before normalization. Cheap, lets each block decide how much it cares about deep context vs the original embedding. + +**Attention branch.** Standard causal multi-head attention with grouped-query (`NUM_HEADS=8`, `NUM_KV_HEADS=4`) and RoPE positional encoding. + +**SSM branch — kill-Mamba-2.** Standard Mamba-2 has an `in_proj` that produces `x` plus three input-dependent quantities `dt`, `B`, `C` ("selectivity" — the ability to modulate the recurrence per-token), runs a depthwise causal `conv1d` (kernel=4) on `x`, then an SSD chunkwise selective scan (`d_state=64`, `expand=2`, `chunk_size=64`, `headdim=64`, 16 SSD heads), then `out_proj`. **kill-Mamba-2 replaces `dt`, `B`, `C` with learned per-head/per-state constants** (`_B_const`, `_C_const`) and a per-head `dt_bias`, making the recurrence linear time-invariant (LTI) instead of input-dependent. Same conv1d, same gating, same `A_log`, same `D_skip`, same in/out projections — only the dynamics become LTI. The intuition: at sub-records training scale, the input-dependent projections are under-trained and add noise; the LTI variant keeps the structural advantages of Mamba-2 (conv1d local recall, gated SSD scan) without the gradient surface area of selectivity. + +**Why parallel attention || SSM.** The two mixers have complementary recall: attention does exact content-addressable lookup over the context, conv1d-equipped SSM does structured local recall. Running them on the same normalized input and summing outputs is consistently better at this scale than alternating attention-only and SSM-only blocks (cross-class hybrid finding from earlier experiments). + +**No BigramHash.** BigramHash recall (`BIGRAM_VOCAB_SIZE=0` here) helps S4D-Lin family blocks but interacts negatively with the conv1d already present in Mamba-2 — it ends up adding noise to the recall niche conv1d already occupies. 
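+
+A minimal sketch of the block wiring and the depth-recurrent loop described above (illustrative only: `ParallelBlock` and `depth_recurrent_forward` are made-up names for this README, and the exact scale / `resid_mix` parameterization in `train_gpt.py` may differ):
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class ParallelBlock(nn.Module):
+    # RMSNorm -> (attention || SSM) summed with per-channel scales into the residual,
+    # then RMSNorm -> SwiGLU MLP with its own scale; resid_mix interpolates the
+    # incoming residual between the live stream x and the post-embedding x0.
+    def __init__(self, dim: int, attn: nn.Module, ssm: nn.Module, mlp: nn.Module):
+        super().__init__()
+        self.attn, self.ssm, self.mlp = attn, ssm, mlp
+        self.attn_scale = nn.Parameter(torch.ones(dim))
+        self.s4d_scale = nn.Parameter(torch.ones(dim))
+        self.mlp_scale = nn.Parameter(torch.ones(dim))
+        self.resid_mix = nn.Parameter(torch.tensor([1.0, 0.0]))  # [weight on x, weight on x0]
+
+    def forward(self, x: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
+        x = self.resid_mix[0] * x + self.resid_mix[1] * x0
+        h = F.rms_norm(x, (x.size(-1),))
+        x = x + self.attn_scale * self.attn(h) + self.s4d_scale * self.ssm(h)
+        h = F.rms_norm(x, (x.size(-1),))
+        return x + self.mlp_scale * self.mlp(h)
+
+def depth_recurrent_forward(blocks: nn.ModuleList, x0: torch.Tensor, num_loops: int = 3) -> torch.Tensor:
+    # 7 stored blocks applied 3 times each: effective depth 21 with 7 layers' worth of weights.
+    x = x0
+    for _ in range(num_loops):
+        for block in blocks:
+            x = block(x, x0)
+    return x
+```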
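+
+The kill-Mamba-2 change itself, reduced to its essence (a sketch, not the in-file implementation; `_B_const`, `_C_const`, and `dt_bias` are the names used above, but their exact shapes and initialization are assumptions, with a single `(d_state,)` B/C shared across heads as in the ngroups=1 SSD reference):
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class KillMamba2Dynamics(nn.Module):
+    # Standard Mamba-2: in_proj(x) also emits per-token dt, B, C (the "selectivity").
+    # kill-Mamba-2: dt, B, C are learned constants, so the recurrence is LTI.
+    def __init__(self, nheads: int = 16, d_state: int = 64):
+        super().__init__()
+        self._B_const = nn.Parameter(torch.randn(d_state) / d_state ** 0.5)
+        self._C_const = nn.Parameter(torch.randn(d_state) / d_state ** 0.5)
+        self.dt_bias = nn.Parameter(torch.zeros(nheads))
+        self.A_log = nn.Parameter(torch.zeros(nheads))  # A = -exp(A_log) = -1 at init
+
+    def scan_inputs(self, bsz: int, seqlen: int):
+        # Broadcast the constants to the (b, L, h) / (b, L, h, n) shapes the SSD
+        # chunkwise scan consumes; nothing here depends on the input tokens.
+        nheads = self.dt_bias.numel()
+        dt = F.softplus(self.dt_bias).expand(bsz, seqlen, -1)      # (b, L, h)
+        A = -torch.exp(self.A_log) * dt                            # dt-scaled scalar decay
+        B = self._B_const.expand(bsz, seqlen, nheads, -1)          # (b, L, h, n)
+        C = self._C_const.expand(bsz, seqlen, nheads, -1)          # (b, L, h, n)
+        return dt, A, B, C
+```
+
+The conv1d, gating, `A_log`, `D_skip`, and in/out projections are untouched; only these scan inputs stop being functions of the input.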
+
+**Quantization (`TERNARY_BODY=1`).** Body weights of the attention `qkv`/`out` projections, Mamba-2 `in_proj`/`out_proj`, and SwiGLU MLP gates are constrained to `{−γ, 0, +γ}` ternary via BitNet-b1.58 absmean straight-through estimation: at every forward pass each weight matrix is quantized to ternary (with a per-tensor scale γ = the weight matrix's mean absolute value, as computed by `BitLinear` in `train_gpt.py`) before the matmul; gradients pass through unchanged. Roughly 1.58 bits per parameter of effective resolution. At quant export, ternary weights are stored as 2 bits per parameter packed (4 vals per byte) in a custom format that bypasses int8 entirely — lossless ternary→ternary round-trip. 1D and small (≤65,536-element) tensors — the SSM dynamics buffers `A_log`, `B_proj`, `C_proj`, `dt_bias`, `D_skip`, `conv1d` weights, all RMSNorm scales, all per-channel scales — stay fp32 throughout via `CONTROL_TENSOR_NAME_PATTERNS`.
+
+**EMA-of-weights (`EMA_BETA=0.999`).** A shadow copy of the model weights is updated each step as `shadow = β·shadow + (1−β)·model`, then swapped into the model immediately before final eval (a minimal sketch of the update and swap follows the Command section below). β=0.999 gives an effective averaging window of ~1000 steps — about the last 23% of this run's 4,380 steps. EMA is reliable when `(1−β)·steps ≫ 1`; for shorter runs (<3,000 steps) β=0.99 is the right choice.
+
+**Optimizer.** Muon (Newton-Schulz orthogonalization, `MUON_BACKEND_STEPS=15`) for 2D weight matrices; AdamW for low-dim parameters and embeddings. `MATRIX_LR=0.045` is the modded-nanogpt baseline value. The optimizer split is by tensor pattern, same as the records-track convention.
+
+**Compression.** int8 quantization of fp32 buffers + 2-bit packed ternary for body weights → brotli q=11 → final artifact (the export path is also sketched below). brotli/zlib ratio on this bytestream is ~0.985; brotli is the standard records-track choice.
+
+## Configuration summary
+
+- Track: `non-record`, 16 MB cap
+- `NUM_UNIQUE_LAYERS=7 NUM_LOOPS=3` (effective depth 21)
+- `MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4`, tied embeddings, sp1024
+- `PARALLEL_LAYER_POSITIONS=0,1,2,3,4,5,6 PARALLEL_SSM_TYPE=mamba2_kill MAMBA2_KILL_SELECTIVITY=1`
+- `BIGRAM_VOCAB_SIZE=0` (off)
+- `TERNARY_BODY=1` (BitNet-b1.58, exported 2-bit packed)
+- `EMA_BETA=0.999` (shadow swapped at last step)
+- Schedule: `WARMDOWN_ITERS=1800 LR_WARMUP_STEPS=30 MATRIX_LR=0.045 TIED_EMBED_INIT_STD=0.05 MUON_BACKEND_STEPS=15 TRAIN_BATCH_TOKENS=524288`
+
+## Command
+
+`train_gpt.py` is fully self-contained — `BitLinear` plus `pack_ternary` / `unpack_ternary` are inlined directly into the script (no local helper modules). Run from the **repo root**, same convention as every other records-folder submission:
+
+```bash
+source records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/env.sh
+torchrun --standalone --nproc_per_node=4 \
+  records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/train_gpt.py
+```
+
+`env.sh` sets `CONTROL_TENSOR_NAME_PATTERNS` (load-bearing — keeps SSM dynamics buffers fp32 under ternary quantization), the topology / quant / EMA / optimizer knobs, and the eval cadence. `DATA_PATH` and `TOKENIZER_PATH` are not exported because their script-side defaults (`./data/datasets/fineweb10B_sp1024` and `./data/tokenizers/fineweb_1024_bpe.model`) already resolve correctly from the repo root. `MAX_WALLCLOCK_SECONDS=3600` is the binding cap; `ITERATIONS=20000` is just an upper bound.
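+
+For reference, the EMA-of-weights update and final swap described above, as a minimal sketch (the function names are illustrative; the in-file logic keyed off `EMA_BETA` / `EMA_WARMUP_OFFSET` may differ in detail):
+
+```python
+import torch
+
+@torch.no_grad()
+def ema_update(shadow: dict, model: torch.nn.Module, beta: float = 0.999) -> None:
+    # shadow = beta * shadow + (1 - beta) * model, accumulated in fp32
+    for name, p in model.named_parameters():
+        shadow[name].mul_(beta).add_(p.detach().float(), alpha=1.0 - beta)
+
+@torch.no_grad()
+def swap_in_shadow(shadow: dict, model: torch.nn.Module) -> None:
+    # copy the averaged weights into the live model right before the final eval / export
+    for name, p in model.named_parameters():
+        p.copy_(shadow[name].to(dtype=p.dtype))
+
+# Usage sketch:
+# shadow = {n: p.detach().float().clone() for n, p in model.named_parameters()}
+# for step in range(num_steps):
+#     ...                               # forward / backward / optimizer step
+#     if step >= ema_warmup_offset:     # before this, the shadow just tracks the current weights
+#         ema_update(shadow, model, beta=0.999)
+# swap_in_shadow(shadow, model)         # then run the final eval + quantized export
+```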
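+
+The export path from the Compression paragraph, as a sketch of the described pipeline rather than the script's exact export code; it assumes `quantize_state_dict_int8` from `train_gpt.py` is passed in as `quantize_fn`:
+
+```python
+import io
+import brotli
+import torch
+
+def export_artifact(state_dict, bitlinear_weight_names, quantize_fn, out_path: str) -> int:
+    # quantize_fn is expected to be train_gpt.py's quantize_state_dict_int8: it returns
+    # (payload_obj, stats) holding int8 / fp16 / 2-bit-packed-ternary tensors.
+    obj, _stats = quantize_fn(state_dict, bitlinear_weight_names)
+    buf = io.BytesIO()
+    torch.save(obj, buf)
+    payload = brotli.compress(buf.getvalue(), quality=11)  # q=11 = max brotli quality
+    with open(out_path, "wb") as f:
+        f.write(payload)
+    return len(payload)  # payload_bytes; artifact total = this + code_bytes, must stay under 16 MB
+```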
+ +## Key metrics + +- Pre-quant val_bpb: `1.2983` +- Post-quant val_bpb: `1.30040229` +- Quant tax: `0.0021` +- Wallclock: `3,600s` (cap fired) +- Step time: `821.94 ms` +- Steps: `4,380` +- Tokens trained: `4,380 × 524,288 ≈ 2.30B` +- Code size: `106,722 bytes` (`train_gpt.py`; single-file, no local helper modules) +- Compressed model payload (int8 + 2-bit-packed ternary + brotli q=11): `11,969,746 bytes` +- **Artifact total: `12,076,468 bytes` = code + payload (≈12.08 MB; 16 MB cap honored with ~3.92 MB headroom)** +- Model parameters: `61,657,752` +- Hardware: 4×H200 SXM (141GB HBM3e per GPU), `--nproc 4`, grad_accum=2 (peak GPU memory at run: 114,125 MiB allocated) + +The train log (`train_seed1337.log`) shows `code_bytes:104676` and `Total submission size int8+zlib:12074422` because it was generated by the original three-file run (`train_gpt.py` + `modules/bitlinear.py` + `modules/trigram_side_memory.py`). The shipped artifact is now single-file (`bitlinear.py` inlined into `train_gpt.py`; `trigram_side_memory.py` deleted as gated dead code under default `TRIGRAM_SIDE_MEMORY=0`), so the post-cleanup byte counts above replace the log values. The trained model checkpoint is unchanged — only the code-side accounting moved. + +## Comparison + +- vs `track_10min_16mb` naive baseline (1.2244, 9L 512d sp1024 GQA-4 tied-emb, 8×H100 10 min): **+0.076 BPB worse** despite ~9× the compute budget. This submission is below the naive records-track baseline; it earns its place on the non-record list as the **first SSM entry** rather than as a frontier number. +- vs `track_10min_16mb` mid-tier records-track frontier `2026-03-31_ParallelResiduals_MiniDepthRecurrence` (1.1063): +0.194 BPB. Most of this gap is duration + missing standard-stack ports (parallel-residuals, sliding-window eval, GPTQ), not architecture. +- vs `track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L...` (1.1239, 8×H100, 2.15h, 8192 BPE): +0.177 BPB +- vs `track_non_record_16mb/2026-03-18_Quasi10Bfrom50B_SP1024_9x512_KV4_4h_pgut3` (1.2074, 8×H100, 4h): +0.093 BPB + +## Files + +- `train_gpt.py` — single-file submission script (106,722 bytes). `BitLinear` plus `pack_ternary` / `unpack_ternary` are inlined; the (gated) `trigram_side_memory` import sites raise `NotImplementedError` so the dead branches are visible-and-unreachable rather than smuggling in a hidden dependency. +- `env.sh` — canonical environment; source from the repo root. +- `train_seed1337.log` — training log (partial: pod was stopped before the full `run.log` synced from `/workspace`; the lines preserved cover the headline numbers — pre/post-quant val_bpb, payload bytes, step times, peak memory, EMA shadow swap). The `code_bytes` / `Total submission size` lines reflect the original three-file run; see "Key metrics" above for the post-cleanup numbers. +- `result.json`, `submission.json` — leaderboard metadata. +- `requirements.txt` — `brotli` and `sentencepiece` are required at quant-export. diff --git a/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/env.sh b/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/env.sh new file mode 100644 index 0000000000..f0fee35d4c --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/env.sh @@ -0,0 +1,54 @@ +# Canonical environment for this submission. 
+# +# Run from the REPO ROOT (where data/ lives): +# +# source records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/env.sh +# torchrun --standalone --nproc_per_node=4 \ +# records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/train_gpt.py +# +# train_gpt.py is fully self-contained — it does NOT import any local helper +# modules. DATA_PATH and TOKENIZER_PATH default to ./data/datasets/fineweb10B_sp1024 +# and ./data/tokenizers/fineweb_1024_bpe.model, both of which resolve from the +# repo root, matching the convention of every other records-folder submission. + +# --- Identity / data --- +export RUN_ID="kill_mamba2_n7_ternary_ema_h200_1hr" +export SEED=1337 +export VOCAB_SIZE=1024 +export TRAIN_SEQ_LEN=1024 + +# --- Topology: 7 unique blocks × 3 weight-shared loops; parallel(attn || kill-Mamba-2) at every position --- +export NUM_UNIQUE_LAYERS=7 +export NUM_LOOPS=3 +export ATTN_LAYER_POSITIONS= +export MAMBA2_LAYER_POSITIONS= +export PARALLEL_LAYER_POSITIONS=0,1,2,3,4,5,6 +export PARALLEL_SSM_TYPE=mamba2_kill +export MAMBA2_KILL_SELECTIVITY=1 +export MLP_TYPE=swiglu +export MLP_MULT=8 +export BIGRAM_VOCAB_SIZE=0 # off — interacts negatively with conv1d already in Mamba-2 + +# --- Quantization: BitNet-b1.58 ternary body, exported 2-bit packed --- +export TERNARY_BODY=1 +# Tensors matching these patterns stay fp32 throughout (1D / small, ≤65,536 elem). +# Load-bearing for SSM dynamics buffers (A_log, B_proj, C_proj, dt_bias, D_skip, conv1d, etc.). +export CONTROL_TENSOR_NAME_PATTERNS="attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,A_log,A_im,B_proj,C_proj,dt_log,D_skip,dt_bias,delta_bias,conv1d" + +# --- EMA-of-weights, β=0.999 (effective averaging window ~1000 steps; shadow swapped before final eval) --- +export EMA_BETA=0.999 + +# --- Schedule / optimizer --- +export TRAIN_BATCH_TOKENS=524288 +export ITERATIONS=20000 # upper bound; the wallclock cap below is the binding limit +export MAX_WALLCLOCK_SECONDS=3600 # 1 hour +export WARMDOWN_ITERS=1800 +export LR_WARMUP_STEPS=30 +export WARMUP_STEPS=0 # batch-size warmup (separate from LR warmup); off +export MATRIX_LR=0.045 +export TIED_EMBED_INIT_STD=0.05 +export MUON_BACKEND_STEPS=15 + +# --- Eval --- +export VAL_TOKENS=0 # 0 = full validation set, writeup-quality +export VAL_LOSS_EVERY=0 # no mid-training val (eval runs once at the end after EMA swap) diff --git a/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/requirements.txt b/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/requirements.txt new file mode 100644 index 0000000000..9a9e81fbdf --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/requirements.txt @@ -0,0 +1,22 @@ +# RunPod / CUDA runtime requirements for parameter-golf-ssm. +# +# The RunPod "Parameter Golf" template comes with PyTorch + CUDA already +# installed, so torch is intentionally NOT pinned here — we use whatever +# CUDA build the image ships and only layer on the small library set the +# training/quantization path needs. +# +# If you ever provision a non-PG image and need to install torch yourself: +# pip install torch --index-url https://download.pytorch.org/whl/cu121 +# (cu121 / cu124 both work; train_gpt.py only relies on stock PyTorch APIs.) 
+# +# Mirrors records/track_10min_16mb/2026-03-31_ParallelResiduals_MiniDepthRecurrence/requirements.txt +# plus the SSM-side memory packing path (brotli) and a couple of helpers. + +numpy +tqdm +huggingface-hub +setuptools +typing-extensions==4.15.0 +datasets +sentencepiece +brotli diff --git a/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/result.json b/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/result.json new file mode 100644 index 0000000000..fadd3f5b17 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/result.json @@ -0,0 +1,27 @@ +{ + "id": "0124_path_a_h200_1hr", + "parent": "0121_path_a_h100_5k", + "created_at": "2026-04-30T03:38:11Z", + "metrics": { + "val_bpb_pre_quant": 1.2983, + "val_bpb_post_quant": 1.30040229, + "val_loss_pre_quant": 2.1921, + "val_loss_post_quant": 2.19567479, + "quant_tax": 0.002102, + "step_avg_ms": 821.94, + "num_steps": 4380.0, + "artifact_bytes": 12074422.0, + "artifact_mb": 12.074, + "code_bytes": 104676.0, + "compression_ratio": 7.7 + }, + "flags": { + "crashed": false, + "size_violation": false, + "has_nan": false, + "exit_code": 0, + "device": "cuda_h200_x4" + }, + "status": "keep", + "description": "Best project result. n=7 long-train SSM at 4×H200 1hr." +} diff --git a/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/submission.json b/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/submission.json new file mode 100644 index 0000000000..f588019e23 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/submission.json @@ -0,0 +1,51 @@ +{ + "author": "potatonyliu", + "github_id": "potatonyliu", + "name": "kill-Mamba-2 SSM + Ternary + n=7 Depth Recurrence + EMA (1-hour 4×H200)", + "blurb": "First SSM-based entry in either track. kill-Mamba-2 (LTI selectivity, B/C constants) in parallel with attention at every block, NUM_UNIQUE_LAYERS=7 with NUM_LOOPS=3 weight sharing, BitNet-b1.58 ternary body with 2-bit-packed export, EMA-of-weights at β=0.999, brotli on top. Trained 4,380 steps × 524,288 batch ≈ 2.3B tokens in 3,600s on 4×H200 SXM (one hour, non-record). Lands at 1.30040 BPB / 12.08 MB; +0.076 BPB above the naive records-track baseline (1.2244) despite ~9× the compute, so this earns its place as the first SSM entry rather than as a frontier number. 
The gap to the records-track frontier (~0.19 BPB vs 1.1063) is dominated by training duration + missing standard-stack ports (parallel-residuals, sliding-window eval, GPTQ) rather than architecture.", + "date": "2026-04-30", + "track": "non_record_16mb", + "val_loss": 2.19567479, + "val_bpb": 1.30040229, + "seeds": [1337], + "seed_results": { + "1337": { + "val_loss_pre_quant": 2.1921, + "val_bpb_pre_quant": 1.2983, + "val_loss_post_quant": 2.19567479, + "val_bpb_post_quant": 1.30040229, + "quant_tax": 0.002102, + "artifact_bytes": 12076468, + "code_bytes": 106722, + "payload_bytes": 11969746, + "steps": 4380, + "step_avg_ms": 821.94, + "device": "cuda_h200_x4", + "wallclock_seconds": 3600, + "model_params": 61657752 + } + }, + "comparison_baseline": { + "track_10min_16mb_naive_baseline": { + "val_bpb": 1.2244, + "path": "records/track_10min_16mb/2026-03-17_NaiveBaseline/README.md", + "delta_bpb": 0.0760, + "note": "9L 512d sp1024 GQA-4 tied-emb on 8×H100 10min — this submission is +0.0760 BPB worse despite ~9× the compute" + }, + "track_10min_16mb_mid_tier_frontier": { + "val_bpb": 1.1063, + "path": "records/track_10min_16mb/2026-03-31_ParallelResiduals_MiniDepthRecurrence/README.md", + "delta_bpb": 0.1941, + "note": "Mid-tier records-track entry (PR #1204), not the baseline; gap is mostly duration + missing standard-stack ports" + }, + "track_non_record_16mb_peers": { + "binary_106M_2.15h": { "val_bpb": 1.1239, "path": "records/track_non_record_16mb/2026-03-24_106M_Binary_Asymmetric_UNet_FP8_15L_8192BPE_YaRN_NeoMuon_Smear/README.md" }, + "transformer_4h": { "val_bpb": 1.2074, "path": "records/track_non_record_16mb/2026-03-18_Quasi10Bfrom50B_SP1024_9x512_KV4_4h_pgut3/README.md" } + } + }, + "bytes_total": 12076468, + "code_bytes": 106722, + "payload_bytes": 11969746, + "hardware": "4×H200 SXM 141GB", + "notes": "Non-record because of hardware (4×H200 not 8×H100) and time (1h not 10min). First SSM entry. Eval is the standard root-harness full-window val + int8/brotli quant roundtrip; not sliding-window. Single-seed (1337); the −0.65 BPB delta over the prior local SSM winner is far outside any reasonable noise floor for this architecture family. Score (1.3004) is +0.076 BPB above the naive records-track baseline (1.2244); positioning as 'requested-PR SSM entry' rather than as a competitive number. The train log records pre-cleanup code_bytes (104,676) / total (12,074,422) from the original three-file run; the shipped single-file artifact has code_bytes=106,722 / total=12,076,468 (BitLinear + pack_ternary / unpack_ternary inlined into train_gpt.py; trigram_side_memory.py deleted as gated dead code). Trained model checkpoint is unchanged — only code-side accounting moved." +} diff --git a/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/train_gpt.py b/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/train_gpt.py new file mode 100644 index 0000000000..d1309c958a --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/train_gpt.py @@ -0,0 +1,2293 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. 
+ +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import contextlib +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +import brotli +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split + # unless VAL_TOKENS > 0, which truncates the val set for faster dev iteration (off by default). + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + val_tokens_cap = int(os.environ.get("VAL_TOKENS", 0)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + # Linear LR warmup applied in lr_mul(). Distinct from warmup_steps, which + # primes compile/kernel paths via a dry-run reset-to-init. + lr_warmup_steps = int(os.environ.get("LR_WARMUP_STEPS", 0)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Depth recurrence (0056): when num_unique_layers > 0, build that many distinct + # Block instances and loop them num_loops times in forward(). Effective depth + # equals num_unique_layers * num_loops. Disables U-Net skip-weights in that + # mode. When num_unique_layers == 0 (default), behavior is identical to canonical. + num_unique_layers = int(os.environ.get("NUM_UNIQUE_LAYERS", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 1)) + attn_layer_positions = os.environ.get("ATTN_LAYER_POSITIONS", "") + # Mamba-2 / SSD mixer positions (0032): per-block indices that should run + # a Mamba-2 SSD chunkwise selective-SSM block instead of S4D-Lin / + # attention. Mutually exclusive with attn_layer_positions. 
+ mamba2_layer_positions = os.environ.get("MAMBA2_LAYER_POSITIONS", "") + # Hymba-strict parallel topology (0025/0027): positions where the block + # runs CausalSelfAttention AND an SSM in parallel on the same attn_norm + # input, summing independently-scaled contributions. The SSM type is + # selected by PARALLEL_SSM_TYPE (default s4d_lin -> S4DLin; mamba2 or + # mamba2_kill -> Mamba2Block, kill flag inherited from + # MAMBA2_KILL_SELECTIVITY). Mutually exclusive with attn_layer_positions + # and mamba2_layer_positions. + parallel_layer_positions = os.environ.get("PARALLEL_LAYER_POSITIONS", "") + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # BigramHash bolt-on (0018): hashes adjacent token pairs into a small extra + # embedding added to tok_emb. When BIGRAM_VOCAB_SIZE=0 (default), the module + # is not constructed and behavior is identical to the baseline. + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 0)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 64)) + + # Trigram side-memory (0067). When TRIGRAM_SIDE_MEMORY=0 (default), + # behavior is byte-identical to parent 0064: no buffers added, forward + # path unchanged. When =1, post-training we build a K-gram + bigram + # backoff predictor from the first TRIGRAM_BUILD_TOKENS train tokens, + # register the pack as model buffers, and have model.forward blend + # model log-probs with the K-gram predictor in log-prob space at eval. + # + # 0068 4-gram extension: + # TRIGRAM_K (default "3"): selects context order. "3" = original 0067 + # trigram behavior; "4" = 4-gram (build top-N contexts, lookup at t>=3). + # + # 0069 multi-K extension: + # TRIGRAM_K is now comma-separated. "3" / "4" are single-K (byte- + # identical to 0067 / 0068 respectively). "3,4" enables COMBINED: + # build BOTH a K=3 and a K=4 pack, and at forward time blend + # model + K=3 + K=4 in log-prob space (3-way logsumexp). + # TRIGRAM_TOP_N_CTX_K3 (default 100000): top-N for K=3 pack (multi-K). + # TRIGRAM_TOP_N_CTX_K4 (default 200000): top-N for K=4 pack (multi-K). + # TRIGRAM_BLEND_WEIGHTS (default "0.8,0.2" single-K, "0.7,0.10,0.20" + # two-K). 
Comma-separated; first weight is the model, the rest are + # per-K in TRIGRAM_K's order. Must sum to 1.0. + # For backward compat, when TRIGRAM_K is a single value: + # - TRIGRAM_TOP_N_CTX (default 0) is honored (used by 0067/0068); + # if both TRIGRAM_TOP_N_CTX and the K-specific override are set, + # the K-specific one wins. + # - TRIGRAM_BLEND_ALPHA (default 0.80) is honored as the model + # weight if TRIGRAM_BLEND_WEIGHTS is not explicitly set. + trigram_side_memory = int(os.environ.get("TRIGRAM_SIDE_MEMORY", 0)) + trigram_top_k = int(os.environ.get("TRIGRAM_TOP_K", 2)) + trigram_blend_alpha = float(os.environ.get("TRIGRAM_BLEND_ALPHA", 0.80)) + trigram_build_tokens = int(os.environ.get("TRIGRAM_BUILD_TOKENS", 100_000_000)) + trigram_min_count = int(os.environ.get("TRIGRAM_MIN_COUNT", 2)) + # TRIGRAM_K is a comma-separated list of ints. Default "3" (single K=3). + _trigram_K_str = os.environ.get("TRIGRAM_K", "3") + trigram_K_list = [int(x.strip()) for x in _trigram_K_str.split(",") if x.strip()] + if not trigram_K_list: + trigram_K_list = [3] + # Single-K back-compat: trigram_K is the int (used by 0067/0068 paths). + trigram_K = trigram_K_list[0] if len(trigram_K_list) == 1 else 0 + # Per-K top-N overrides for multi-K. Single-K honors TRIGRAM_TOP_N_CTX. + trigram_top_n_ctx = int(os.environ.get("TRIGRAM_TOP_N_CTX", 0)) + trigram_top_n_ctx_k3 = int(os.environ.get("TRIGRAM_TOP_N_CTX_K3", 100_000)) + trigram_top_n_ctx_k4 = int(os.environ.get("TRIGRAM_TOP_N_CTX_K4", 200_000)) + # Blend weights. If unset, derive from TRIGRAM_BLEND_ALPHA for single-K + # (preserving 0067/0068 behavior) or default to (0.7, 0.10, 0.20) for + # multi-K K=3,4 (matching the offline reference). + _trigram_blend_weights_str = os.environ.get("TRIGRAM_BLEND_WEIGHTS", "") + if _trigram_blend_weights_str: + trigram_blend_weights = [ + float(x.strip()) for x in _trigram_blend_weights_str.split(",") if x.strip() + ] + elif len(trigram_K_list) == 1: + # Single-K default: (alpha, 1 - alpha). Byte-identical to 0067/0068. + trigram_blend_weights = [trigram_blend_alpha, 1.0 - trigram_blend_alpha] + else: + # Multi-K K=3,4 default: (model 0.7, K=3 0.10, K=4 0.20). + trigram_blend_weights = [0.7, 0.10, 0.20] + # Validate + if len(trigram_blend_weights) != 1 + len(trigram_K_list): + raise ValueError( + f"TRIGRAM_BLEND_WEIGHTS has {len(trigram_blend_weights)} weights but " + f"TRIGRAM_K has {len(trigram_K_list)} K's; expected " + f"{1 + len(trigram_K_list)} (model + per-K)" + ) + _w_sum = sum(trigram_blend_weights) + if abs(_w_sum - 1.0) > 1e-4: + raise ValueError( + f"TRIGRAM_BLEND_WEIGHTS must sum to 1.0, got {_w_sum:.6f} ({trigram_blend_weights})" + ) + + # 0074 per-context α: when PER_CONTEXT_ALPHA=1, derive the model blend + # weight per (b, t) position from the trigram's per-context entropy. Pack + # adds ~0.2 MB int8 α buffer per K. Default 0 → byte-identical to parent + # 0069 (no alpha buffers built/registered, multi-K blend uses fixed weights). + per_context_alpha = int(os.environ.get("PER_CONTEXT_ALPHA", 0)) + alpha_tau = float(os.environ.get("ALPHA_TAU", 0.5)) + alpha_thresh = float(os.environ.get("ALPHA_THRESH", 5.0)) + alpha_min = float(os.environ.get("ALPHA_MIN", 0.5)) + alpha_max = float(os.environ.get("ALPHA_MAX", 0.95)) + + # 0076 confidence-gated blend: when set above the default, positions where + # the model's max log2-prob exceeds this threshold (model is confident) + # skip the trigram blend and use model log-probs alone. -1e9 (default) = + # gate disabled, byte-identical to parent 0074. 
+ conf_gate_threshold = float(os.environ.get("CONF_GATE_THRESHOLD", -1e9)) + + # 0116 EMA-of-weights. Default EMA_BETA=0.0 disables (byte-identical to + # parent 0107). When > 0, after EMA_WARMUP_OFFSET completed steps we run + # w_ema = beta*w_ema + (1-beta)*w_current (fp32) on a separate shadow dict; + # at last_step we swap the shadow into model params for the final eval + + # quant export. Pre-warmup the shadow tracks current (no decay; avoids + # averaging random init). EMA_WARMUP_OFFSET="" (default) -> lr_warmup_steps. + ema_beta: float = float(os.environ.get("EMA_BETA", "0.0")) + ema_warmup_offset_raw: str = os.environ.get("EMA_WARMUP_OFFSET", "") + ema_warmup_offset: int = lr_warmup_steps if ema_warmup_offset_raw == "" else int(ema_warmup_offset_raw) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. 
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + autocast_ctx, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + # MPS does not support float64 and has indexing bugs; run eval accounting + # on CPU when not on CUDA. No-op on CUDA (accum_device == device). 
+ accum_device = device if device.type == "cuda" else torch.device("cpu") + val_loss_sum = torch.zeros((), device=accum_device, dtype=torch.float64) + val_token_count = torch.zeros((), device=accum_device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=accum_device, dtype=torch.float64) + base_bytes_lut_acc = base_bytes_lut.to(accum_device) + has_leading_space_lut_acc = has_leading_space_lut.to(accum_device) + is_boundary_token_lut_acc = is_boundary_token_lut.to(accum_device) + + model.eval() + with torch.inference_mode(): + batch_starts = range(seq_start, seq_end, local_batch_seqs) + for batch_seq_start in batch_starts: + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + # MPS corrupts the first element of non-zero-offset int64 views when + # copy is non_blocking; use a blocking copy for validation tokens. + local = val_tokens[raw_start:raw_end].to(dtype=torch.int64).to(device=device) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with autocast_ctx: + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(accum_device).to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1).to(accum_device) + tgt_ids = y.reshape(-1).to(accum_device) + token_bytes = base_bytes_lut_acc[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut_acc[tgt_ids] & ~is_boundary_token_lut_acc[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + # 0067: keep trigram side-memory scales/offsets and unigram/bigram log2p + # tables at fp32 (no fp16 downcast on save). Their values decode int8 + # buffers, so even small fp16 rounding shifts every probability and + # poisons the blend. Cost: a few KB extra in the artifact. + ",".join(CONTROL_TENSOR_NAME_PATTERNS) + + ",trigram_log2p_scale,trigram_log2p_offset" + + ",trigram_log2_backoff_scale,trigram_log2_backoff_offset" + # 0069 multi-K: K-suffixed scale/offset buffers (trigram3_*, trigram4_*). 
+ + ",trigram3_log2p_scale,trigram3_log2p_offset" + + ",trigram3_log2_backoff_scale,trigram3_log2_backoff_offset" + + ",trigram4_log2p_scale,trigram4_log2p_offset" + + ",trigram4_log2_backoff_scale,trigram4_log2_backoff_offset" + + ",bigram_log2p_scales,bigram_log2p_offsets" + + ",unigram_log2p", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor], bitlinear_weight_names: set[str] = frozenset()): + # 0086 v2: BitLinear weight names get packed-ternary 2-bit storage instead of int8. + # Single supported clean-script export format: + # - packed-ternary for names in bitlinear_weight_names (4 vals/byte) + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + packed_ternary: dict[str, Tensor] = {} + packed_ternary_meta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # 0086: pack BitLinear weights at 2 bits/param. Detected by name set. 
+ if name in bitlinear_weight_names: + packed, gamma = pack_ternary(t) + packed_ternary[name] = packed + packed_ternary_meta[name] = { + "gamma": float(gamma), + "shape": list(t.shape), + "dtype": str(t.dtype).removeprefix("torch."), + } + stats["int8_payload_bytes"] += tensor_nbytes(packed) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1+packed_ternary", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + if packed_ternary: + obj["packed_ternary"] = packed_ternary + obj["packed_ternary_meta"] = packed_ternary_meta + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + # 0086 v2: also restores packed-ternary weights via unpack_ternary. + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + # 0086: unpack ternary-packed weights. 
+ packed_ternary = obj.get("packed_ternary", {}) + packed_ternary_meta = obj.get("packed_ternary_meta", {}) + for name, packed in packed_ternary.items(): + meta = packed_ternary_meta[name] + gamma = meta["gamma"] + shape = tuple(meta["shape"]) + dtype = getattr(torch, meta["dtype"]) + out[name] = unpack_ternary(packed, gamma, shape).to(dtype=dtype).contiguous() + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +# 0083: BitNet b1.58 ternary BodyLinear. When TERNARY_BODY=1, body matmuls +# (attn qkv/proj, mamba2 in/out_proj, MLP gates) use ternary {-γ,0,+γ} +# weights with absmean γ + STE. Forward always ternarizes; quant export +# packs ternary at ~2 bits/param. Reference: Ma et al. 2024, "The Era of +# 1-bit LLMs", arxiv:2402.17764, eq (1)-(2). Inlined here so the submission +# is single-file; pack_ternary/unpack_ternary used by quantize_state_dict_int8 +# below. Encoding offset (+1): -1→0b00, 0→0b01, +1→0b10. 4 vals per byte. + +class BitLinear(nn.Linear): + """BitNet b1.58 ternary weight Linear layer. 
Drop-in for nn.Linear.""" + + def forward(self, x: Tensor) -> Tensor: + w = self.weight + gamma = w.abs().mean().clamp(min=1e-8) + w_q = (w / gamma).round().clamp_(-1.0, 1.0) + w_ste = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, (gamma * w_ste).to(x.dtype), bias) + + +def pack_ternary(w: Tensor) -> tuple[Tensor, float]: + """Ternarize w via absmean γ, pack into uint8 (4 vals/byte).""" + w_flat = w.detach().contiguous().view(-1).float() + gamma = w_flat.abs().mean().clamp(min=1e-8).item() + w_q = (w_flat / gamma).round().clamp_(-1.0, 1.0).to(torch.int8) + codes = (w_q + 1).to(torch.uint8) + n = codes.numel() + pad = (4 - n % 4) % 4 + if pad: + codes = torch.cat([codes, torch.zeros(pad, dtype=torch.uint8)]) + quad = codes.view(-1, 4) + packed = ( + quad[:, 0] | (quad[:, 1] << 2) | (quad[:, 2] << 4) | (quad[:, 3] << 6) + ).contiguous() + return packed, gamma + + +def unpack_ternary(packed: Tensor, gamma: float, shape) -> Tensor: + """Inverse of pack_ternary. Scales so BitLinear's forward-time recompute + of γ' = abs.mean(W) recovers the original γ — keeps post-quant output + bit-equivalent to training.""" + n_total = 1 + for s in shape: + n_total *= s + p = packed.to(torch.uint8) + c0 = p & 0b11 + c1 = (p >> 2) & 0b11 + c2 = (p >> 4) & 0b11 + c3 = (p >> 6) & 0b11 + codes = torch.stack([c0, c1, c2, c3], dim=1).view(-1) + codes = codes[:n_total].to(torch.int8) + w_q = codes.to(torch.float32) - 1.0 + n_nonzero = (w_q != 0).sum().item() + frac_nonzero = max(n_nonzero / max(n_total, 1), 1e-8) + alpha = gamma / frac_nonzero + return (w_q * alpha).view(*shape) + + +if os.environ.get("TERNARY_BODY", "0") == "1": + BodyLinear = BitLinear +else: + BodyLinear = CastedLinear + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = BodyLinear(dim, dim, bias=False) + self.c_k = BodyLinear(dim, kv_dim, bias=False) + self.c_v = BodyLinear(dim, kv_dim, bias=False) + self.proj = BodyLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class S4DLin(nn.Module): + """S4D-Lin (LTI diagonal SSM, FFT-conv) drop-in replacement for attention. + + Input/output shape: (B, L, dim). SSM dynamics held in fp32; the input + branch (in_proj output) is cast to fp32 for the FFT conv and skip, then + out_proj returns to the input dtype. Matches the recurrence to ~6e-08 on + the tiny case (see scratch/s4d_lin_tiny.py). 
+ """ + + def __init__(self, dim: int, d_state: int = 16, expand: int = 1): + super().__init__() + self.dim = dim + self.d_inner = expand * dim + self.d_state = d_state + self.in_proj = CastedLinear(dim, self.d_inner, bias=False) + self.out_proj = CastedLinear(self.d_inner, dim, bias=False) + self.out_proj._zero_init = True + # A_re = -exp(A_log) initialized to -0.5 everywhere (S4D-Lin real part). + self.A_log = nn.Parameter( + torch.full((self.d_inner, d_state), float(math.log(0.5)), dtype=torch.float32) + ) + # A_im[d, n] = pi * n, broadcast across the d_inner axis. + a_im_init = ( + (math.pi * torch.arange(d_state, dtype=torch.float32)) + .unsqueeze(0) + .expand(self.d_inner, d_state) + .contiguous() + .clone() + ) + self.A_im = nn.Parameter(a_im_init) + self.B_proj = nn.Parameter( + torch.randn(self.d_inner, d_state, dtype=torch.float32) / math.sqrt(d_state) + ) + self.C_proj = nn.Parameter( + torch.randn(self.d_inner, d_state, dtype=torch.float32) / math.sqrt(d_state) + ) + # dt log-uniform in [0.001, 0.1], stored via inverse-softplus so + # F.softplus(dt_log) recovers dt. + u_dt = torch.rand(self.d_inner, dtype=torch.float32) + dt_init = torch.exp(u_dt * (math.log(0.1) - math.log(0.001)) + math.log(0.001)) + self.dt_log = nn.Parameter(torch.log(torch.expm1(dt_init))) + self.D_skip = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + bsz, L, _ = x.shape + u = self.in_proj(x) # (B, L, d_inner) + # Build kernel in fp32 (numerics matter here). + A_re = -torch.exp(self.A_log) # (D, N) fp32 + A = torch.complex(A_re, self.A_im) # (D, N) complex64 + dt = F.softplus(self.dt_log) # (D,) fp32 + log_A_bar = dt[:, None] * A # (D, N) complex64 + A_bar = torch.exp(log_A_bar) # (D, N) complex64, |A_bar| < 1 since A_re < 0 + # B_bar = (A_bar - 1) / A * B (exact ZOH for B, primer 1.2) + B_bar = ((A_bar - 1.0) / A) * self.B_proj.to(A.dtype) + coefs = self.C_proj.to(A.dtype) * B_bar # (D, N) complex64 + powers = torch.arange(L, device=A.device, dtype=A_re.dtype) + log_kernel = log_A_bar[:, :, None] * powers[None, None, :] # (D, N, L) complex64 + K = 2.0 * (coefs[:, :, None] * torch.exp(log_kernel)).sum(dim=1).real # (D, L) fp32 + # FFT conv per channel, padded to 2L for linear (causal) conv. + nfft = 2 * L + u_t = u.transpose(1, 2).contiguous().float() # (B, d_inner, L) fp32 + Uf = torch.fft.rfft(u_t, n=nfft) + Kf = torch.fft.rfft(K, n=nfft) # (D, nfft//2+1) + Yf = Uf * Kf[None, :, :] + y = torch.fft.irfft(Yf, n=nfft)[..., :L] # (B, d_inner, L) + y = y.transpose(1, 2).contiguous() # (B, L, d_inner) + y = y + self.D_skip * u.float() # D skip + return self.out_proj(y.to(x.dtype)) + + +# ----------------------------------------------------------------------------- +# Mamba-2 / SSD (State Space Duality) block (0032) +# +# Adapted from the official minimal SSD reference (state-spaces/mamba, Apache +# License 2.0, Copyright (c) 2024 Albert Gu and Tri Dao). +# Source: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/ssd_minimal.py +# Fetched: 2026-04-26 +# +# This is the chunkwise SSD algorithm from Listing 1 of the Mamba-2 paper +# (Dao & Gu, "Transformers are SSMs: Generalized Models and Efficient +# Algorithms Through Structured State Space Duality", arXiv:2405.21060). +# +# Local modifications from the vendored reference: +# - Wrapped in a Mamba2Block nn.Module that exposes the (B, L, dim) -> +# (B, L, dim) contract used by the rest of the model body. 
+# - Constructor takes (dim, d_state, expand, chunk_size, headdim) directly. +# - Inner projections use the worktree's CastedLinear (fp32 weights, bf16 +# compute). +# - Per-channel skip is named `D_skip` (instead of upstream `D`) so the +# existing CONTROL_TENSOR_NAME_PATTERNS substring `D_skip` keeps the +# param fp32 without overmatching. +# - `segsum` rewritten without einops.repeat (we have no einops dep) using +# plain torch ops; equivalent semantics to the upstream stable form. +# - Removed einops dependency in the chunkwise scan body; einsum/reshape +# via torch.reshape and explicit dim args. +# - SSD chunkwise scan is run in fp32 throughout (A, dt, B, C, X all cast +# to fp32 inside ssd_minimal_discrete) for numerical stability on MPS. +# - ngroups=1 (single shared (B, C) per chunk position, broadcast across +# heads) -- this is the most common SSD configuration. +# - dt_bias initialized via softplus^-1 of a log-uniform sample in +# [dt_min, dt_max] (matches Mamba-2 paper §3.6 / official repo). +# - A_log initialized to log(1) = 0 so initial A = -1 per head (uniform, +# stable). The official mamba_ssm/modules/mamba2.py uses +# log(uniform(1, 16)); we use a single value for simplicity at our +# 16-head scale. +# ----------------------------------------------------------------------------- + + +def _ssd_segsum(x: Tensor) -> Tensor: + """Stable segment-sum (cumulative sum within lower-triangular blocks). + + Adapted from ssd_minimal.py `segsum` (state-spaces/mamba). For input + shape (..., T), returns (..., T, T) where output[..., i, j] equals + sum_{k=j+1}^{i} x[..., k] for i >= j else -inf. + Equivalent to taking cumsum over j..i exclusive of j; -inf entries are + later exponentiated to 0 by the caller (segment-decay mask). + """ + T = x.size(-1) + # Replicate x along a new last dim of size T -> (..., T, T). + x_expanded = x.unsqueeze(-1).expand(*x.shape, T) + # Zero out entries strictly above the strict-lower triangle so the + # cumulative sum along dim=-2 only accumulates between strictly-below + # diagonal positions (matches segsum semantics). + mask_keep = torch.tril( + torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=-1 + ) + x_masked = x_expanded.masked_fill(~mask_keep, 0.0) + x_segsum = torch.cumsum(x_masked, dim=-2) + # Above-diagonal entries are not valid segment sums; mark with -inf so + # exp(...) later gives 0 (proper causal/segment mask). + mask_valid = torch.tril( + torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=0 + ) + x_segsum = x_segsum.masked_fill(~mask_valid, float("-inf")) + return x_segsum + + +def ssd_minimal_discrete( + X: Tensor, + A: Tensor, + B: Tensor, + C: Tensor, + block_len: int, + initial_states: Tensor | None = None, +) -> tuple[Tensor, Tensor]: + """Minimal SSD chunkwise scan (Listing 1 of Mamba-2 paper). + + Adapted verbatim (modulo the einops removal) from + state-spaces/mamba/mamba_ssm/modules/ssd_minimal.py. + + Inputs (already discretized -- caller multiplies X*=dt and A*=dt before + invoking, matching the canonical reference's calling convention): + X: (b, L, n_heads, d_head) -- dt-scaled inputs + A: (b, L, n_heads) -- dt-scaled scalar decay per head + B: (b, L, n_heads, d_state) -- input-dependent (broadcast for + ngroups=1: same across heads) + C: (b, L, n_heads, d_state) -- output-dependent (likewise) + Returns: + Y: (b, L, n_heads, d_head) + final_state: (b, n_heads, d_head, d_state) + Requires L % block_len == 0. 
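+
+ Steps 1, 2 and 4 below are purely intra-chunk einsums; step 3 carries
+ state across the L/block_len chunk summaries (a segsum over chunk-end
+ decays) and is the only step that couples chunks.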
+ """ + assert X.dtype == A.dtype == B.dtype == C.dtype + b, L, h, p = X.shape + n = B.size(-1) + if L % block_len != 0: + raise ValueError(f"L={L} must be divisible by block_len={block_len}") + c_count = L // block_len + + # Rearrange (b, L, ...) -> (b, c, l, ...) where l = block_len. + X = X.reshape(b, c_count, block_len, h, p) + A = A.reshape(b, c_count, block_len, h) + B = B.reshape(b, c_count, block_len, h, n) + C = C.reshape(b, c_count, block_len, h, n) + + # A in (b, h, c, l) layout for cumsum/segsum along the within-chunk axis. + A_perm = A.permute(0, 3, 1, 2).contiguous() # (b, h, c, l) + A_cumsum = torch.cumsum(A_perm, dim=-1) + + # 1. Intra-chunk diagonal-block output. + # L_intra[..., l, s] = exp( segsum_l - segsum_s ) (l >= s, else 0). + L_intra = torch.exp(_ssd_segsum(A_perm)) # (b, h, c, l, s) + Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L_intra, X) + + # 2. Compute the SSM state contributed by each chunk (right factor of the + # low-rank off-diagonal blocks; B-side). + decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) # (b, h, c, l) + states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X) + + # 3. Inter-chunk recurrence (the only non-parallel step). Pad an initial + # state (zeros if not supplied) and do a chunk-axis segsum on the + # cumulative chunk-end decays. + if initial_states is None: + initial_states = torch.zeros_like(states[:, :1]) + states = torch.cat([initial_states, states], dim=1) + chunk_end_A = A_cumsum[:, :, :, -1] # (b, h, c) + chunk_end_A = F.pad(chunk_end_A, (1, 0)) # (b, h, c+1) leading 0 + decay_chunk = torch.exp(_ssd_segsum(chunk_end_A)) # (b, h, c+1, c+1) + new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) + states_per_chunk, final_state = new_states[:, :-1], new_states[:, -1] + + # 4. State -> output conversion per chunk (left factor; C-side). + state_decay_out = torch.exp(A_cumsum) # (b, h, c, l) + Y_off = torch.einsum( + "bclhn,bchpn,bhcl->bclhp", C, states_per_chunk, state_decay_out + ) + + Y = (Y_diag + Y_off).reshape(b, L, h, p) + return Y, final_state + + +class Mamba2Block(nn.Module): + """Mamba-2 (Dao & Gu 2024, arXiv:2405.21060) SSD selective-SSM mixer. + + Drop-in replacement for an attention/S4D mixer: takes (B, L, dim) and + returns (B, L, dim). Internal width is d_inner = expand * dim, split into + nheads = d_inner // headdim heads each of dimension headdim. The SSD + chunkwise scan replaces Mamba-1's sequential selective scan: most of the + work is matmul (intra-chunk diagonal blocks) so it runs much better on + MPS than the per-step Python loop in MambaBlock. + """ + + def __init__( + self, + dim: int, + d_state: int = 64, + expand: int = 2, + chunk_size: int = 64, + headdim: int = 64, + d_conv: int = 4, + conv_bias: bool = True, + ): + super().__init__() + self.dim = dim + self.expand = expand + self.d_inner = int(expand * dim) + if self.d_inner % headdim != 0: + raise ValueError( + f"d_inner={self.d_inner} must be divisible by headdim={headdim}" + ) + self.headdim = headdim + self.nheads = self.d_inner // headdim + self.d_state = d_state + self.chunk_size = chunk_size + self.d_conv = d_conv + + # in_proj produces [x_branch, z_gate, B, C, dt] in a single matmul. + # Shapes: x_branch d_inner, z d_inner, B d_state (ngroups=1 shared), + # C d_state (ngroups=1 shared), dt nheads. + self.in_proj = BodyLinear( + dim, + 2 * self.d_inner + 2 * d_state + self.nheads, + bias=False, + ) + + # Causal depthwise conv on x branch (matches Mamba-1). 
+ self.conv1d = nn.Conv1d( + in_channels=self.d_inner, + out_channels=self.d_inner, + bias=conv_bias, + kernel_size=d_conv, + groups=self.d_inner, + padding=d_conv - 1, + ) + + # Output projection. Zero-init so block is a no-op at start. + self.out_proj = BodyLinear(self.d_inner, dim, bias=False) + self.out_proj._zero_init = True + + # SSD parameters (per-head). + # A_log: A = -exp(A_log) per head (scalar A_t = a_t I, with a_t<0 + # discount). Initialized to log(1)=0 -> A=-1 (matches the official + # mamba_ssm A_init's lower bound). + self.A_log = nn.Parameter(torch.zeros(self.nheads, dtype=torch.float32)) + # Per-channel skip term (the "D" in y = SSM(x) + D*x). + self.D_skip = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32)) + # dt bias: softplus^-1 of log-uniform dt sample on [1e-3, 1e-1] + # (Mamba paper §3.6 init; pulled forward to Mamba-2). + dt_min, dt_max = 1e-3, 1e-1 + with torch.no_grad(): + dt = torch.exp( + torch.rand(self.nheads, dtype=torch.float32) + * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ).clamp(min=1e-4) + inv_dt = dt + torch.log(-torch.expm1(-dt)) # softplus^-1 + self.dt_bias = nn.Parameter(inv_dt) + + # 0038: selectivity-kill ablation. When MAMBA2_KILL_SELECTIVITY=1, + # replace input-dependent (dt, B, C) with learned per-head/per-state + # constants in forward. Keeps the rest of the block (conv1d, in_proj, + # out_proj, D_skip, A_log, dt_bias, gate via z) identical to full + # Mamba-2 — only the dynamics become LTI. This is the cleanest test + # of "is selectivity load-bearing?" — if val ≈ 2.04 the win is + # parameters/structure, if val ≈ 2.16 selectivity is the mechanism. + # The B/C/dt projections inside in_proj remain (slight param waste + # for cleaner ablation framing — same in_proj shape as full Mamba-2). + self._kill_selectivity = os.environ.get("MAMBA2_KILL_SELECTIVITY", "0") == "1" + if self._kill_selectivity: + # Per-state constants for B and C (broadcast across batch+time). + self._B_const = nn.Parameter( + torch.randn(d_state, dtype=torch.float32) / math.sqrt(d_state) + ) + self._C_const = nn.Parameter( + torch.randn(d_state, dtype=torch.float32) / math.sqrt(d_state) + ) + # dt: only dt_bias contributes (no per-token dt projection). + # No new param needed — dt_bias is reused. + + def forward(self, x: Tensor) -> Tensor: + b, l, _ = x.shape + if l % self.chunk_size != 0: + raise ValueError( + f"Mamba2Block: seq_len={l} must be divisible by chunk_size={self.chunk_size}" + ) + in_dtype = x.dtype + + proj = self.in_proj(x) + # Split: x (d_inner), z (d_inner), B (d_state), C (d_state), dt (nheads). + x_branch, z, B_in, C_in, dt = proj.split( + [self.d_inner, self.d_inner, self.d_state, self.d_state, self.nheads], + dim=-1, + ) + + # 0038 selectivity-kill: override dt/B/C with learned constants + # (broadcast across batch+time). Equivalent to LTI dynamics with the + # same surrounding structure as full Mamba-2. + if self._kill_selectivity: + dt = torch.zeros_like(dt) # dt_bias-only contribution after softplus + B_in = self._B_const.view(1, 1, self.d_state).expand(b, l, self.d_state).to(B_in.dtype) + C_in = self._C_const.view(1, 1, self.d_state).expand(b, l, self.d_state).to(C_in.dtype) + + # Causal depthwise conv on x_branch. Run conv1d in fp32 (weight is + # restored to fp32 via CONTROL_TENSOR_NAME_PATTERNS substring 'conv1d', + # bias is auto-fp32 as 1D). Cast back after. 
+ x_conv = x_branch.transpose(1, 2).float() # (b, d_inner, l) + x_conv = self.conv1d(x_conv)[:, :, :l] + x_conv = x_conv.transpose(1, 2).contiguous() # (b, l, d_inner) + x_act = F.silu(x_conv) + + # Reshape x_act to per-head (b, l, nheads, headdim). + x_h = x_act.reshape(b, l, self.nheads, self.headdim) + + # Compute (dt, A) for SSD. dt = softplus(dt_in + dt_bias); + # A = -exp(A_log). + dt_f = F.softplus(dt.float() + self.dt_bias.float()) # (b, l, nheads) + A_f = -torch.exp(self.A_log.float()) # (nheads,) + + # Discretize: SSD reference takes X*=dt and A*=dt as inputs. Applied + # in fp32 throughout for stability. + X_disc = x_h.float() * dt_f.unsqueeze(-1) # (b, l, nheads, headdim) + A_disc = A_f[None, None, :] * dt_f # (b, l, nheads) + + # B and C are shared across heads (ngroups=1) so we broadcast a + # singleton head axis. + B_b = B_in.float().unsqueeze(2).expand(b, l, self.nheads, self.d_state) + C_b = C_in.float().unsqueeze(2).expand(b, l, self.nheads, self.d_state) + + Y, _ = ssd_minimal_discrete( + X_disc, A_disc, B_b, C_b, block_len=self.chunk_size + ) # (b, l, nheads, headdim) fp32 + + y = Y.reshape(b, l, self.d_inner) + + # D skip + gate. y, x_act, z all currently in fp32 / in_dtype mix; + # do the gate in fp32 then cast back. + y = y + self.D_skip.float() * x_act.float() + y = y * F.silu(z.float()) + return self.out_proj(y.to(in_dtype)) + + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class MLP(nn.Module): + # MLP_TYPE=relu2 (default): relu^2 MLP from modded-nanogpt + # MLP_TYPE=swiglu: SwiGLU (silu-gated up-proj, then down-proj) + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self._mlp_type = os.environ.get("MLP_TYPE", "relu2") + if self._mlp_type == "swiglu": + self.w_gate = BodyLinear(dim, hidden, bias=False) + self.w_up = BodyLinear(dim, hidden, bias=False) + self.w_down = BodyLinear(hidden, dim, bias=False) + self.w_down._zero_init = True + else: + self.fc = BodyLinear(dim, hidden, bias=False) + self.proj = BodyLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + if self._mlp_type == "swiglu": + return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)) + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_attention: bool = False, + mamba2_mode: bool = False, + parallel_mode: bool = False, + ): + super().__init__() + # The three mixer modes (use_attention, mamba2_mode, parallel_mode) + # all 
populate self.attn (and parallel_mode additionally constructs + # self.s4d), so they must be mutually exclusive. + mode_count = int(use_attention) + int(mamba2_mode) + int(parallel_mode) + if mode_count > 1: + raise ValueError( + "Block: at most one of use_attention, mamba2_mode, parallel_mode " + f"may be set, got use_attention={use_attention}, " + f"mamba2_mode={mamba2_mode}, parallel_mode={parallel_mode}." + ) + self.parallel_mode = parallel_mode + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + if parallel_mode: + # Hymba-strict (0025/0027): run attention AND an SSM in parallel + # on the same attn_norm input, summing independently-scaled + # contributions. The SSM type is selected by PARALLEL_SSM_TYPE + # (default s4d_lin preserves byte-identical behavior to 0027). + # When set to mamba2 or mamba2_kill, the parallel SSM becomes a + # Mamba2Block; the kill-vs-full distinction then follows + # MAMBA2_KILL_SELECTIVITY (read inside Mamba2Block.__init__), + # so no extra plumbing is needed here. + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + parallel_ssm_type = os.environ.get("PARALLEL_SSM_TYPE", "s4d_lin") + if parallel_ssm_type == "s4d_lin": + self.s4d = S4DLin(dim, d_state=16, expand=1) + elif parallel_ssm_type in ("mamba2", "mamba2_kill"): + self.s4d = Mamba2Block(dim, d_state=64, expand=2, chunk_size=64, headdim=64) + else: + raise ValueError( + f"PARALLEL_SSM_TYPE={parallel_ssm_type!r} not recognized; " + f"expected one of: s4d_lin, mamba2, mamba2_kill." + ) + self.s4d_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + elif use_attention: + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + elif mamba2_mode: + # Mamba-2 / SSD selective-SSM mixer (0032). Same forward contract + # as CausalSelfAttention / S4DLin: (B, L, dim) -> (B, L, dim). 
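+ # Stored under the same attribute name (self.attn) as the attention and
+ # S4D mixers so Block.forward below stays mixer-agnostic.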
+ self.attn = Mamba2Block(dim, d_state=64, expand=2, chunk_size=64, headdim=64) + else: + self.attn = S4DLin(dim, d_state=16, expand=1) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + normed = self.attn_norm(x) + if self.parallel_mode: + attn_out = self.attn(normed) + s4d_out = self.s4d(normed) + x = ( + x + + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + + self.s4d_scale.to(dtype=x.dtype)[None, None, :] * s4d_out + ) + else: + attn_out = self.attn(normed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_unique_layers: int = 0, + num_loops: int = 1, + attn_layer_positions: set[int] = None, + mamba2_layer_positions: set[int] = None, + parallel_layer_positions: set[int] = None, + bigram_vocab_size: int = 0, + bigram_dim: int = 64, + ): + super().__init__() + attn_positions = attn_layer_positions or set() + mamba2_positions = mamba2_layer_positions or set() + parallel_positions = parallel_layer_positions or set() + if attn_positions & mamba2_positions: + raise ValueError( + f"attn_layer_positions and mamba2_layer_positions must be disjoint, " + f"got overlap {sorted(attn_positions & mamba2_positions)} " + f"(attn={sorted(attn_positions)}, mamba2={sorted(mamba2_positions)})" + ) + if attn_positions & parallel_positions: + raise ValueError( + f"attn_layer_positions and parallel_layer_positions must be disjoint, " + f"got overlap {sorted(attn_positions & parallel_positions)} " + f"(attn={sorted(attn_positions)}, parallel={sorted(parallel_positions)})" + ) + if mamba2_positions & parallel_positions: + raise ValueError( + f"mamba2_layer_positions and parallel_layer_positions must be disjoint, " + f"got overlap {sorted(mamba2_positions & parallel_positions)} " + f"(mamba2={sorted(mamba2_positions)}, parallel={sorted(parallel_positions)})" + ) + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + if bigram_vocab_size > 0: + self.bigram_hash = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) + else: + self.bigram_hash = None + # Depth recurrence (0056): when num_unique_layers > 0, build K distinct + # blocks and disable U-Net skip-weights. forward() loops the K blocks L + # times for an effective depth of K*L. When num_unique_layers == 0, fall + # back to canonical U-Net path with num_layers distinct blocks. 
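+ # Example: NUM_UNIQUE_LAYERS=7, NUM_LOOPS=3 -> blocks[0..6] applied three
+ # times in order (effective depth 21) with only 7 blocks' weights stored.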
+ self.recurrent = num_unique_layers > 0 + self.num_loops = num_loops + if self.recurrent: + effective_layers = num_unique_layers + self.num_encoder_layers = 0 + self.num_decoder_layers = num_unique_layers + self.num_skip_weights = 0 + # No skip_weights parameter in recurrent mode (no U-Net structure). + print( + f"[GPT] Depth-recurrence ENABLED: K={num_unique_layers} unique blocks, " + f"L={num_loops} loops, effective depth K*L={num_unique_layers * num_loops}. " + f"U-Net skip-weights DISABLED." + ) + else: + effective_layers = num_layers + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + use_attention=(i in attn_positions), + mamba2_mode=(i in mamba2_positions), + parallel_mode=(i in parallel_positions), + ) + for i in range(effective_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + # 0067 trigram side-memory: opt-in flag (default False). When False, + # forward path is identical to parent 0064 — and no trigram buffers exist. + # Buffers + flag are installed post-training in main() if enabled. + # 0068: _trigram_K (default 3) selects K=3 (trigram) vs K=4 (4-gram) + # path; lookup positions and key-encoding scale with K. + # 0069: _trigram_K_list (list[int]) supports multi-K (e.g. [3, 4] for + # combined). When len > 1, forward dispatches to multi-K blend. + # _trigram_blend_weights: [w_model, w_K1, w_K2, ...] (single-K = + # [alpha, 1-alpha], two-K = e.g. [0.7, 0.10, 0.20]). + self._use_trigram_blend = False + self._trigram_blend_alpha = 0.80 + self._trigram_vocab_size = vocab_size + self._trigram_K = 3 + self._trigram_K_list = [3] + self._trigram_blend_weights = [0.8, 0.2] + # 0074 per-context α (default off → byte-identical to 0069 multi-K). + self._per_context_alpha = False + self._alpha_max_default = 0.95 + # 0076 confidence-gated blend (default disabled → byte-identical to 0074). + self._conf_gate_threshold = -1e9 + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + if self.recurrent: + # Depth-recurrence: loop the K unique blocks L times. No U-Net skips. + for _loop_idx in range(self.num_loops): + for layer_idx in range(len(self.blocks)): + x = self.blocks[layer_idx](x, x0) + else: + skips: list[Tensor] = [] + # First half stores skips; second half reuses them in reverse order. 
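+ # Each popped encoder activation is rescaled per-channel by its
+ # skip_weights[i] row before being added back into the decoder stream.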
+ for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x_norm = self.final_norm(x) + # If trigram blending is OFF (default), preserve the original code path + # exactly to keep TRIGRAM_SIDE_MEMORY=0 byte-identical to parent 0064. + if not self._use_trigram_blend: + x_flat = x_norm.reshape(-1, x_norm.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + # Trigram-blended path. We need (B, L, V) logits to blend with trigram + # log-probs that depend on input positions, so do not flatten. + if self.tie_embeddings: + logits_proj = F.linear(x_norm, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_norm) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + log_softmax = F.log_softmax(logits.float(), dim=-1) # (B, L, V) nats + + # 0069 multi-K: when more than one K is configured, dispatch to the + # multi-K blend (3-way logsumexp of model + K=3 + K=4). Otherwise, + # use the original single-K path (byte-identical to 0067/0068). + if len(self._trigram_K_list) > 1: + raise NotImplementedError( + "trigram_side_memory removed in submission cleanup; this submission " + "uses TRIGRAM_SIDE_MEMORY=0 (default) so this code path is unreachable." + ) + packs_by_K: dict = {} + for K in self._trigram_K_list: + pfx = f"trigram{K}_" + pack_K = { + "keys": getattr(self, pfx + "keys"), + "offsets": getattr(self, pfx + "offsets"), + "next": getattr(self, pfx + "next"), + "log2p_quant": getattr(self, pfx + "log2p_quant"), + "log2p_scale": getattr(self, pfx + "log2p_scale"), + "log2p_offset": getattr(self, pfx + "log2p_offset"), + "log2_backoff_quant": getattr(self, pfx + "log2_backoff_quant"), + "log2_backoff_scale": getattr(self, pfx + "log2_backoff_scale"), + "log2_backoff_offset": getattr(self, pfx + "log2_backoff_offset"), + } + if self._per_context_alpha: + pack_K["alpha_quant"] = getattr(self, pfx + "alpha_quant") + pack_K["alpha_scale"] = getattr(self, pfx + "alpha_scale") + pack_K["alpha_offset"] = getattr(self, pfx + "alpha_offset") + packs_by_K[K] = pack_K + return trigram_blend_loss_multi_K( + log_softmax, + target_ids, + input_ids, + packs_by_K=packs_by_K, + bigram_log2p_quant=self.bigram_log2p_quant, + bigram_log2p_scales=self.bigram_log2p_scales, + bigram_log2p_offsets=self.bigram_log2p_offsets, + unigram_log2p=self.unigram_log2p, + blend_weights=self._trigram_blend_weights, + vocab_size=self._trigram_vocab_size, + K_order=self._trigram_K_list, + per_context_alpha=self._per_context_alpha, + alpha_max_default=self._alpha_max_default, + conf_gate_threshold=self._conf_gate_threshold, + ) + + raise NotImplementedError( + "trigram_side_memory removed in submission cleanup; this submission " + "uses TRIGRAM_SIDE_MEMORY=0 (default) so this code path is unreachable." 
+ ) + return trigram_blend_loss( + log_softmax, + target_ids, + input_ids, + trigram_keys=self.trigram_keys, + trigram_offsets=self.trigram_offsets, + trigram_next=self.trigram_next, + trigram_log2p_quant=self.trigram_log2p_quant, + trigram_log2p_scale=self.trigram_log2p_scale, + trigram_log2p_offset=self.trigram_log2p_offset, + trigram_log2_backoff_quant=self.trigram_log2_backoff_quant, + trigram_log2_backoff_scale=self.trigram_log2_backoff_scale, + trigram_log2_backoff_offset=self.trigram_log2_backoff_offset, + bigram_log2p_quant=self.bigram_log2p_quant, + bigram_log2p_scales=self.bigram_log2p_scales, + bigram_log2p_offsets=self.bigram_log2p_offsets, + unigram_log2p=self.unigram_log2p, + blend_alpha=self._trigram_blend_alpha, + vocab_size=self._trigram_vocab_size, + K=self._trigram_K, + conf_gate_threshold=self._conf_gate_threshold, + ) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + + # ----------------------------- + # DISTRIBUTED + DEVICE SETUP + # ----------------------------- + + if torch.cuda.is_available(): + device_type = "cuda" + elif torch.backends.mps.is_available(): + device_type = "mps" + else: + device_type = "cpu" + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ and device_type == "cuda" + rank = int(os.environ.get("RANK", "0")) if distributed else 0 + world_size = int(os.environ.get("WORLD_SIZE", "1")) if distributed else 1 + local_rank = int(os.environ.get("LOCAL_RANK", "0")) if distributed else 0 + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if device_type == "cuda": + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + else: + device = torch.device(device_type) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + if device_type == "cuda": + autocast_ctx = torch.autocast("cuda", dtype=torch.bfloat16) + else: + autocast_ctx = contextlib.nullcontext() + + def sync() -> None: + if device_type == "cuda": + torch.cuda.synchronize() + elif device_type == "mps": + torch.mps.synchronize() + + # Fast math knobs (CUDA only) + if device_type == "cuda": + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + if device_type == "cuda": + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg, flush=True) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + if device_type == "cuda": + log0( + 
subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_tokens_cap > 0: + cap = ((args.val_tokens_cap - 1) // args.train_seq_len) * args.train_seq_len + 1 + cap = min(cap, val_tokens.numel()) + val_tokens = val_tokens[:cap] + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + attn_layer_positions = { + int(x) for x in args.attn_layer_positions.split(",") if x.strip() + } + mamba2_layer_positions = { + int(x) for x in args.mamba2_layer_positions.split(",") if x.strip() + } + parallel_layer_positions = { + int(x) for x in args.parallel_layer_positions.split(",") if x.strip() + } + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_unique_layers=args.num_unique_layers, + num_loops=args.num_loops, + attn_layer_positions=attn_layer_positions, + mamba2_layer_positions=mamba2_layer_positions, + parallel_layer_positions=parallel_layer_positions, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if device_type == "cuda": + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + else: + compiled_model = base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + # 0028/0032: route 3D+ params (e.g. 
Mamba's depthwise conv1d.weight, shape + # (d_inner, 1, d_conv)) to scalar/Adam too. Pre-0028 the predicate was + # "ndim < 2 OR matched", which left 3D unmatched params in NEITHER bucket + # -- they would receive zero updates. Old behavior is preserved because + # no <0028 module had 3D params, and Mamba2Block's conv1d is matched by + # the 'conv1d' substring in CONTROL_TENSOR_NAME_PATTERNS regardless. + scalar_params = [ + p + for name, p in block_named_params + if p.ndim != 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + # 0018: include BigramHash params (which live outside base_model.blocks). + if getattr(base_model, "bigram_hash", None) is not None: + for name, p in base_model.bigram_hash.named_parameters(): + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + matrix_params.append(p) + else: + scalar_params.append(p) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=(device_type == "cuda"), + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=(device_type == "cuda"), + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=(device_type == "cuda"), + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"lr_warmup_steps:{args.lr_warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + # Linear LR warmup over the first lr_warmup_steps. 
`step` is 0-indexed + # at the moment of the optimizer call, so (step + 1) / N gives a ramp + # of 1/N, 2/N, ..., 1.0 across N steps. Default N=0 disables warmup. + if args.lr_warmup_steps > 0 and step < args.lr_warmup_steps: + return (step + 1) / args.lr_warmup_steps + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with autocast_ctx: + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + sync() + t0 = time.perf_counter() + + # 0116 EMA-of-weights shadow init. Default beta=0.0 -> ema_shadow stays None + # and all the EMA blocks below short-circuit (byte-identical to parent 0107). + ema_shadow: dict[str, torch.Tensor] | None = None + if args.ema_beta > 0.0: + ema_shadow = { + name: p.detach().float().clone() + for name, p in base_model.named_parameters() + if p.requires_grad + } + log0(f"[ema] enabled beta={args.ema_beta} warmup_offset={args.ema_warmup_offset} " + f"shadowing {len(ema_shadow)} parameters") + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + sync() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + # 0116 EMA: swap shadow into model parameters for the final eval. + # Pre-quant eval, quant export, and post-quant eval all see the shadow weights. + # We don't restore — training is done at this point. 
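+ # The shadow is updated below after every optimizer step as
+ # w_ema = beta*w_ema + (1-beta)*w; its effective averaging window is on
+ # the order of 1/(1-beta) steps.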
+ if ema_shadow is not None and last_step: + log0(f"[ema] swapping shadow into model for final eval (beta={args.ema_beta})") + with torch.no_grad(): + for name, p in base_model.named_parameters(): + if name in ema_shadow: + p.data.copy_(ema_shadow[name].to(p.dtype)) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + autocast_ctx, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + sync() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with autocast_ctx: + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + # 0116 EMA-of-weights shadow update. At this point, this iteration's + # opt.step has just been applied; step is still the pre-increment count + # (number of completed steps BEFORE this iteration). After this update + # the count of completed steps is (step + 1), so we use that for the + # warmup gate. + if ema_shadow is not None: + if (step + 1) >= args.ema_warmup_offset: + # EMA update: w_ema = beta*w_ema + (1-beta)*w_current. fp32 throughout. + beta = args.ema_beta + one_minus_beta = 1.0 - beta + with torch.no_grad(): + for name, p in base_model.named_parameters(): + if name in ema_shadow: + ema_shadow[name].mul_(beta).add_(p.detach().float(), alpha=one_minus_beta) + else: + # Pre-warmup: shadow tracks current (no decay; avoids averaging random init). + with torch.no_grad(): + for name, p in base_model.named_parameters(): + if name in ema_shadow: + ema_shadow[name].copy_(p.detach().float()) + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
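+ # All ranks must agree on stop_after_step: a rank that kept training past
+ # the cap would block in the next DDP all-reduce while the others had
+ # already left the loop.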
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + if device_type == "cuda": + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # 0067 TRIGRAM SIDE-MEMORY (post-training, pre-serialize) + # ----------------------------- + # Build a trigram + bigram backoff predictor from the first + # TRIGRAM_BUILD_TOKENS train tokens and install it as buffers on + # base_model. After this, model.forward auto-blends model log-probs with + # trigram log-probs in log-prob space (alpha = TRIGRAM_BLEND_ALPHA). + # Buffers are integer/fp32 -> they ride into state_dict and through the + # quantize_state_dict_int8 passthrough path (int tensors skip quant; small + # float tensors get the keep_float path). + if args.trigram_side_memory: + raise NotImplementedError( + "trigram_side_memory removed in submission cleanup; this submission " + "uses TRIGRAM_SIDE_MEMORY=0 (default) so this code path is unreachable." + ) + log0(f"[trigram] TRIGRAM_SIDE_MEMORY=1 — building pack post-training") + log0( + f"[trigram] TRIGRAM_K={args.trigram_K_list} " + f"TRIGRAM_TOP_N_CTX_K3={args.trigram_top_n_ctx_k3} " + f"TRIGRAM_TOP_N_CTX_K4={args.trigram_top_n_ctx_k4} " + f"TRIGRAM_TOP_N_CTX={args.trigram_top_n_ctx} " + f"TRIGRAM_BUILD_TOKENS={args.trigram_build_tokens:,} " + f"TRIGRAM_TOP_K={args.trigram_top_k} " + f"TRIGRAM_MIN_COUNT={args.trigram_min_count} " + f"TRIGRAM_BLEND_WEIGHTS={args.trigram_blend_weights}" + ) + log0( + f"[trigram] PER_CONTEXT_ALPHA={args.per_context_alpha} " + f"ALPHA_TAU={args.alpha_tau} ALPHA_THRESH={args.alpha_thresh} " + f"ALPHA_MIN={args.alpha_min} ALPHA_MAX={args.alpha_max}" + ) + log0(f"[trigram] CONF_GATE_THRESHOLD={args.conf_gate_threshold}") + # Read first N tokens from the train glob. TokenStream wraps shards; + # use it directly so we don't double-load. + train_stream = TokenStream(args.train_files) + train_tokens_t = train_stream.take(args.trigram_build_tokens) + train_tokens_np = train_tokens_t.to(torch.int64).cpu().numpy() + + K_list = list(args.trigram_K_list) + + if len(K_list) == 1: + # Single-K path — byte-identical to 0067 (K=3, top_n=0) and 0068 + # (K=4, top_n=200000) with original buffer names (no K-suffix). + K = K_list[0] + # For single-K, honor the original TRIGRAM_TOP_N_CTX env var (0067/ + # 0068's interface). If it's the default 0 and the user set a + # K-specific override, use the K-specific value. This keeps single-K + # backward-compat while letting K=3-only users opt into K3-pruning. + top_n = args.trigram_top_n_ctx + pack = build_trigram_pack( + train_tokens_np, + vocab_size=args.vocab_size, + top_k=args.trigram_top_k, + min_count=args.trigram_min_count, + K=K, + top_n_ctx=top_n, + ) + raw_bytes = pack_byte_size(pack) + log0(f"[trigram] pack raw bytes: {raw_bytes:,} ({raw_bytes / 1e6:.2f} MB)") + # Register buffers on base_model — original (un-suffixed) names. 
+ for name, tensor in pack.items(): + base_model.register_buffer(name, tensor.to(device), persistent=True) + base_model._use_trigram_blend = True + base_model._trigram_blend_alpha = args.trigram_blend_alpha + base_model._trigram_vocab_size = args.vocab_size + base_model._trigram_K = K + base_model._trigram_K_list = [K] + base_model._trigram_blend_weights = list(args.trigram_blend_weights) + base_model._conf_gate_threshold = float(args.conf_gate_threshold) + log0( + f"[trigram] installed {len(pack)} buffers on base_model; " + f"_use_trigram_blend=True K={K} weights={args.trigram_blend_weights} " + f"conf_gate_threshold={args.conf_gate_threshold}" + ) + else: + # Multi-K path (0069). Build one pack per K. Use K-specific top-N. + # Buffers are stored with K-suffixed names: trigram3_keys, etc. + # The bigram + unigram fallback tables are SHARED across K's + # (built identically; we install them once from the first pack). + K_to_top_n = {3: args.trigram_top_n_ctx_k3, 4: args.trigram_top_n_ctx_k4} + packs_by_K: dict = {} + for K in K_list: + if K not in K_to_top_n: + raise ValueError( + f"Multi-K with K={K} not supported (only 3, 4 have top-N defaults)" + ) + top_n = K_to_top_n[K] + log0(f"[trigram] building K={K} pack (top_n_ctx={top_n})") + packs_by_K[K] = build_trigram_pack( + train_tokens_np, + vocab_size=args.vocab_size, + top_k=args.trigram_top_k, + min_count=args.trigram_min_count, + K=K, + top_n_ctx=top_n, + per_context_alpha=bool(args.per_context_alpha), + alpha_tau=args.alpha_tau, + alpha_thresh=args.alpha_thresh, + alpha_min=args.alpha_min, + alpha_max=args.alpha_max, + ) + # Install shared (un-suffixed) bigram + unigram from first pack. + shared_keys = ( + "bigram_log2p_quant", + "bigram_log2p_scales", + "bigram_log2p_offsets", + "unigram_log2p", + ) + first_K = K_list[0] + first_pack = packs_by_K[first_K] + for key in shared_keys: + base_model.register_buffer(key, first_pack[key].to(device), persistent=True) + # Install K-specific buffers under K-suffixed names. + kspec_keys = ( + "trigram_keys", + "trigram_offsets", + "trigram_next", + "trigram_log2p_quant", + "trigram_log2p_scale", + "trigram_log2p_offset", + "trigram_log2_backoff_quant", + "trigram_log2_backoff_scale", + "trigram_log2_backoff_offset", + ) + # 0074: per-context α buffers (only present when PER_CONTEXT_ALPHA=1). + kspec_alpha_keys = ( + "trigram_alpha_quant", + "trigram_alpha_scale", + "trigram_alpha_offset", + ) + total_raw = 0 + for K in K_list: + pack = packs_by_K[K] + for key in kspec_keys: + new_name = key.replace("trigram_", f"trigram{K}_", 1) + base_model.register_buffer(new_name, pack[key].to(device), persistent=True) + if args.per_context_alpha: + for key in kspec_alpha_keys: + if key not in pack: + raise RuntimeError( + f"PER_CONTEXT_ALPHA=1 but K={K} pack missing {key}" + ) + new_name = key.replace("trigram_", f"trigram{K}_", 1) + base_model.register_buffer(new_name, pack[key].to(device), persistent=True) + # Track raw size including this K's bigram/unigram only once + pack_raw = pack_byte_size(pack) + # Subtract shared-buffers' bytes for non-first K's (they're not + # duplicated in the model, but pack_byte_size counts them). 
+ shared_bytes = sum(int(pack[k].numel()) * int(pack[k].element_size()) for k in shared_keys) + if K != first_K: + pack_raw -= shared_bytes + log0( + f"[trigram] K={K} pack raw bytes: {pack_raw:,} " + f"({pack_raw / 1e6:.2f} MB)" + ) + total_raw += pack_raw + log0(f"[trigram] total raw bytes (model side): {total_raw:,} ({total_raw / 1e6:.2f} MB)") + base_model._use_trigram_blend = True + base_model._trigram_blend_alpha = args.trigram_blend_alpha + base_model._trigram_vocab_size = args.vocab_size + base_model._trigram_K = K_list[0] + base_model._trigram_K_list = list(K_list) + base_model._trigram_blend_weights = list(args.trigram_blend_weights) + base_model._per_context_alpha = bool(args.per_context_alpha) + base_model._alpha_max_default = float(args.alpha_max) + base_model._conf_gate_threshold = float(args.conf_gate_threshold) + log0( + f"[trigram] multi-K installed: K_list={K_list} " + f"weights={args.trigram_blend_weights} " + f"_use_trigram_blend=True _per_context_alpha={base_model._per_context_alpha} " + f"_conf_gate_threshold={base_model._conf_gate_threshold}" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # 0086 v2: build the set of weight names belonging to BitLinear modules so + # quantize_state_dict_int8 routes them through pack_ternary instead of int8. 
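+ # pack_ternary stores these as 2-bit codes (4 weights per byte) instead of
+ # int8; all other tensors keep the usual int8 / keep-float handling.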
+ bitlinear_weight_names = { + f"{mod_name}.weight" for mod_name, mod in base_model.named_modules() + if isinstance(mod, BitLinear) + } + log0(f"v2 packed-ternary export: {len(bitlinear_weight_names)} BitLinear weights → 2-bit packed") + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict(), bitlinear_weight_names) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = brotli.compress(quant_raw, quality=11) + quant_blob_zlib = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + log0( + f"brotli_vs_zlib brotli_bytes:{len(quant_blob)} zlib_bytes:{len(quant_blob_zlib)} " + f"savings_bytes:{len(quant_blob_zlib) - len(quant_blob)} " + f"ratio:{len(quant_blob) / max(len(quant_blob_zlib), 1):.4f}" + ) + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(brotli.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + sync() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + autocast_ctx, + ) + sync() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/train_seed1337.log b/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/train_seed1337.log new file mode 100644 index 0000000000..015bb21e08 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_KillMamba2_TriParallel_n7_Ternary_EMA_4xH200_1hr/train_seed1337.log @@ -0,0 +1,54 @@ +# Partial training log for seed 1337 +# (Pod was stopped before the full run.log could be synced from /workspace; the lines below were captured via the monitor stream during training and are sufficient to verify the headline numbers.) + +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=../../data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=../../data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +[GPT] Depth-recurrence ENABLED: K=7 unique blocks, L=3 loops, effective depth K*L=21. U-Net skip-weights DISABLED. 
+model_params:61657752 +world_size:4 grad_accum_steps:2 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.045 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:0 lr_warmup_steps:30 max_wallclock_seconds:3600.000 +seed:1337 +[ema] enabled beta=0.999 warmup_offset=30 shadowing 148 parameters + +# --- training (every 100 steps) --- +step:100/20000 train_loss:3.3586 train_time:187400ms step_avg:1874.00ms +step:200/20000 train_loss:2.8262 train_time:275752ms step_avg:1378.76ms +step:300/20000 train_loss:2.5926 train_time:349866ms step_avg:1166.22ms +step:400/20000 train_loss:2.3766 train_time:438347ms step_avg:1095.87ms +step:500/20000 train_loss:2.4614 train_time:512539ms step_avg:1025.08ms +step:600/20000 train_loss:2.5687 train_time:600424ms step_avg:1000.71ms +step:700/20000 train_loss:2.4633 train_time:674462ms step_avg:963.52ms +step:800/20000 train_loss:2.3569 train_time:751007ms step_avg:938.76ms +step:900/20000 train_loss:2.3796 train_time:825810ms step_avg:917.57ms +step:1000/20000 train_loss:2.4222 train_time:909235ms step_avg:909.24ms +step:1500/20000 train_loss:2.3330 train_time:1289594ms step_avg:859.73ms +step:2000/20000 train_loss:2.2737 train_time:1702121ms step_avg:851.06ms +step:2500/20000 train_loss:2.2367 train_time:2095799ms step_avg:838.32ms +step:3000/20000 train_loss:2.2197 train_time:2489456ms step_avg:829.82ms +step:3500/20000 train_loss:2.2448 train_time:2903686ms step_avg:829.62ms +step:3900/20000 train_loss:2.0272 train_time:3221626ms step_avg:826.06ms +step:4000/20000 train_loss:2.1408 train_time:3295749ms step_avg:823.94ms +step:4380/20000 train_loss:2.0161 train_time:3600100ms step_avg:821.94ms + +# --- eval (post-EMA-shadow swap) --- +[ema] swapping shadow into model for final eval (beta=0.999) +step:4380/20000 val_loss:2.1921 val_bpb:1.2983 train_time:3600102ms step_avg:821.94ms +stopping_early: wallclock_cap train_time:3600102ms step:4380/20000 +peak memory allocated: 114125 MiB reserved: 115026 MiB + +# --- quantize export --- +Serialized model: 123491967 bytes +Code size: 104676 bytes +Total submission size: 123596643 bytes +v2 packed-ternary export: 63 BitLinear weights → 2-bit packed +Serialized model int8+zlib: 11969746 bytes (payload:16033632 raw_torch:16083093 payload_ratio:7.70x) +Total submission size int8+zlib: 12074422 bytes +brotli_vs_zlib brotli_bytes:11969746 zlib_bytes:12156986 savings_bytes:187240 ratio:0.9846 + +# --- post-quant eval (int8 + brotli round-trip) --- +final_int8_zlib_roundtrip val_loss:2.1957 val_bpb:1.3004 eval_time:357422ms +final_int8_zlib_roundtrip_exact val_loss:2.19567479 val_bpb:1.30040229