diff --git a/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/README.md b/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/README.md
new file mode 100644
index 0000000000..9fd58d2166
--- /dev/null
+++ b/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/README.md
@@ -0,0 +1,93 @@
# Approach Study: Higher-Rank Output Heads on a Frontier 11L Baseline

This folder documents a non-record study of higher-rank output heads on top of a fixed frontier-aligned 11-layer baseline.

Research question:
- On a strong fixed 11L control, can higher-rank output heads outperform the standard tied softmax head under the 10-minute training budget?

## Summary

- Control: the standard tied head reached `1.1734 val_bpb` in `600s`.
- Result: every tested higher-rank head variant underperformed the control, often by a large margin.
- Artifact impact: mixture heads increased artifact size, while the simplex head reduced artifact size substantially but collapsed the score.
- Main finding: in this frontier-aligned small-budget regime, the standard tied head remained the strongest option. Extra output-head structure behaved as an optimization burden rather than a compression win.

## Fixed Baseline

All runs used the same training and evaluation stack:
- 11 layers, `d_model=512`, `8` query heads, `4` KV heads
- `MLP_MULT=3`
- EMA from init (`alpha=0.997`)
- XSA on the last `4` layers
- SmearGate enabled
- BigramHash enabled (`2048` buckets, `128` dim)
- partial RoPE (`16` rotary dims) with NTK-aware scaling
- LN Scale enabled
- VE128 enabled on layers `9,10`
- Late QAT enabled at `lr_scale < 0.15`
- `seq2048`, `786432` train tokens/step
- sliding evaluation (`stride=64`)
- `8xH100`, `600s` wallclock cap
- Hopper FA3, compiled training, and the real quantization/artifact path

Only one family parameter changed across the study:
- output-head type and its local bottleneck settings

## Variants

Tested head family:
- `H0`: standard tied head
- `H1`: factorized head, rank `64`
- `H2`: factorized head, rank `128`
- `H3`: mixture-softmax, `K=2`, rank `64`
- `H4`: mixture-softmax, `K=4`, rank `64`
- `H5`: mixture-softmax, `K=4`, rank `128`
- `H6`: simplex head, bottleneck `128`

## Results

| Run | Head Variant | `val_bpb` | Δ vs `H0` | Steps | Artifact bytes | Notes |
|-----|--------------|----------:|----------:|------:|---------------:|-------|
| `H0` | standard tied head | `1.1734` | `0.0000` | `4415` | `16826913` | control |
| `H1` | factorized `r=64` | `2.4396` | `+1.2662` | `4451` | `16729834` | severe degradation |
| `H2` | factorized `r=128` | `1.9227` | `+0.7494` | `4425` | `16918260` | still far worse than control |
| `H3` | MoS `K=2`, `r=64` | `2.6167` | `+1.4434` | `4428` | `16565348` | severe degradation |
| `H4` | MoS `K=4`, `r=64` | `2.7112` | `+1.5379` | `4149` | `17172588` | worst mixture result |
| `H5` | MoS `K=4`, `r=128` | `2.0898` | `+0.9165` | `4160` | `17943057` | worse score and larger artifact |
| `H6` | simplex `128` | `4.1069` | `+2.9336` | `4241` | `10950817` | smallest artifact, unusable score |

The result is unambiguous: none of the tested higher-rank heads improved on the frontier-aligned control, and several failed catastrophically.

## Interpretation

This study does not show that higher-rank output heads are useless in general.
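
As a point of reference, the factorized and mixture families correspond roughly to the sketch below. This is a minimal illustration only, not the study-local implementation in `train_gpt.py`: the module names and exact gating layout are assumptions, and the simplex head is omitted because its construction is specific to the study code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FactorizedHead(nn.Module):
    """H1/H2 shape: low-rank logits, d_model -> rank -> vocab."""
    def __init__(self, d_model: int, vocab_size: int, rank: int):
        super().__init__()
        self.down = nn.Linear(d_model, rank, bias=False)
        self.up = nn.Linear(rank, vocab_size, bias=False)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.up(self.down(h))  # logits through a rank-r bottleneck

class MixtureSoftmaxHead(nn.Module):
    """H3-H5 shape: K low-rank softmax experts blended by a learned gate."""
    def __init__(self, d_model: int, vocab_size: int, k: int, rank: int):
        super().__init__()
        self.experts = nn.ModuleList(
            FactorizedHead(d_model, vocab_size, rank) for _ in range(k)
        )
        self.gate = nn.Linear(d_model, k, bias=False)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        pi = F.softmax(self.gate(h), dim=-1)  # (..., K) mixture weights
        expert_probs = torch.stack(
            [F.softmax(e(h), dim=-1) for e in self.experts], dim=-1
        )  # (..., vocab, K)
        mixed = (expert_probs * pi.unsqueeze(-2)).sum(dim=-1)  # (..., vocab)
        return mixed.clamp_min(1e-9).log()  # log-probs for an NLL loss
```

The tied control (`H0`), by contrast, reuses the token embedding matrix as the output projection, so it adds no head-specific parameters to train or to compress.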
It shows something narrower and still useful: +- on this specific frontier-aligned 11L budgeted regime, +- with a strong tied-head baseline already in place, +- extra output-head structure was harder to optimize than the standard head, +- and the added expressivity did not translate into better compression. + +The negative result is still useful for future work: +- if this family is revisited, it likely needs a different training regime rather than a direct swap on top of a tuned small-budget control +- the simplex head is notable as an artifact-size reduction idea, but not as a quality-preserving one in this form +- the mixture-head variants were the clearest failure mode: more parameters in the output head did not buy better compression here + +## Why There Is No Separate Confirmatory Matrix + +Unlike the semantic-tube study, this family sweep was already run on the intended fast path: +- compiled training +- Hopper FA3 +- full `80` training shards +- sliding evaluation +- real quantization and artifact generation + +So the family sweep itself already serves as the authoritative result set for this study. + +## Included Files + +Included here: +- `family_heads.jsonl`: raw study results +- `family_heads_review.md`: compact study summary +- `train_gpt.py`: self-contained study-local training script +- `install_flash_attn_hopper.sh`: Hopper-only FA3 installer used by the study runner +- `run_higher_rank_heads_study.sh`: self-contained family runner +- `REPRODUCE.md`: reproduction commands diff --git a/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/REPRODUCE.md b/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/REPRODUCE.md new file mode 100644 index 0000000000..b7af941e26 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/REPRODUCE.md @@ -0,0 +1,27 @@ +# Reproduction + +The study runner is self-contained: +- it uses the `train_gpt.py` included in this folder +- it ensures the full `fineweb10B_sp1024` dataset is present (`80` training shards) +- it installs or reuses a Hopper-only FA3 wheel before training +- it runs the full 7-variant family on the intended fast path +- it copies the fresh per-run console logs for the rerun into `logs/` inside this folder + +## Run The Full Family + +```bash +bash records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/run_higher_rank_heads_study.sh +``` + +This reruns the full family: +- `H0`: standard tied head +- `H1`: factorized `r=64` +- `H2`: factorized `r=128` +- `H3`: mixture-softmax `K=2`, `r=64` +- `H4`: mixture-softmax `K=4`, `r=64` +- `H5`: mixture-softmax `K=4`, `r=128` +- `H6`: simplex `128` + +Expected budget: +- `7` runs +- about `70` minutes total on `8xH100` diff --git a/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/family_heads.jsonl b/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/family_heads.jsonl new file mode 100644 index 0000000000..cac6c0fbcc --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/family_heads.jsonl @@ -0,0 +1,7 @@ +{"id": 1, "timestamp": "2026-03-26T18:28:50Z", "approach": "head_family", "variation": "control_standard", "subvariation": "seed42", "commit": "76ceb8e", "env_vars": {"NUM_LAYERS": "11", "MODEL_DIM": "512", "NUM_HEADS": "8", "NUM_KV_HEADS": "4", "MLP_MULT": "3", "MATRIX_LR": "0.025", "SCALAR_LR": "0.025", "MUON_MOMENTUM": "0.99", "SKIP_QUANT": "0", "MAX_WALLCLOCK_SECONDS": "600", "LAMBDA_SPECTRAL": "0.0", "LAMBDA_TUBE": "0.0", "RANK_DIM": "0", 
"ATTENTION_TYPE": "standard", "EVAL_STRIDE": "64", "WARMDOWN_ITERS": "3500", "COMPRESS_ALGO": "zstd", "QUANT_BITS_MIDDLE": "6", "NUM_GPUS": "8", "USE_QAT": "0", "QAT_BITS": "6", "TRAIN_SEQ_LEN": "2048", "SMEAR_GATE": "1", "BIGRAM_HASH": "1", "BIGRAM_VOCAB_SIZE": "2048", "BIGRAM_DIM": "128", "MUON_WEIGHT_DECAY": "0.04", "ADAM_WD": "0.04", "INIT_TYPE": "ortho", "USE_EMA": "1", "EMA_ALPHA": "0.997", "USE_NTK_ROPE": "1", "USE_FA3": "1", "FLASH_ATTN_BACKEND": "fa3", "FLASH_ATTN_STRICT": "1", "FLASH_ATTN_ARCH_LIST": "9.0", "QUANT_BITS_MLP": "6", "TIED_EMBED_LR": "0.035", "MUON_MOMENTUM_WARMUP_START": "0.92", "MUON_MOMENTUM_WARMUP_STEPS": "1500", "GRAD_CLIP_NORM": "0.3", "USE_TTT": "0", "XSA_LAST_N": "4", "ROPE_DIMS": "16", "LN_SCALE": "1", "LATE_QAT": "1", "QAT_THRESHOLD": "0.15", "TRAIN_BATCH_TOKENS": "786432", "EMA_FROM_INIT": "1", "GRAD_GUIDED_QUANT": "0", "BACKOUT_CONNECTION": "0", "SPECTRAL_WARMDOWN": "0", "VALUE_EMBED": "1", "VE_DIM": "128", "VE_LAYERS": "9,10", "TOKEN_SALIENCY": "0", "QUANT_ERROR_DIFFUSION": "0", "QUANT_HADAMARD": "0", "LAMBDA_VICREG": "0.0", "HYPERGRAPH_LIFT": "0", "HYPERGRAPH_SCALES": "2,4,8", "HEAD_TYPE": "standard"}, "fn_injection": null, "val_bpb": 1.1733618, "val_loss": 1.97221607, "peak_memory_mb": 26709, "training_time_ms": 600059, "num_params": 26993766, "num_steps": 4415, "artifact_size_bytes": 16826913, "artifact_est_bytes": 0, "status": "pending", "reward": 0.5, "description": "H0: Control baseline (standard head)", "tags": []} +{"id": 2, "timestamp": "2026-03-26T18:41:35Z", "approach": "head_family", "variation": "factorized_r64", "subvariation": "seed42", "commit": "76ceb8e", "env_vars": {"NUM_LAYERS": "11", "MODEL_DIM": "512", "NUM_HEADS": "8", "NUM_KV_HEADS": "4", "MLP_MULT": "3", "MATRIX_LR": "0.025", "SCALAR_LR": "0.025", "MUON_MOMENTUM": "0.99", "SKIP_QUANT": "0", "MAX_WALLCLOCK_SECONDS": "600", "LAMBDA_SPECTRAL": "0.0", "LAMBDA_TUBE": "0.0", "RANK_DIM": "64", "ATTENTION_TYPE": "standard", "EVAL_STRIDE": "64", "WARMDOWN_ITERS": "3500", "COMPRESS_ALGO": "zstd", "QUANT_BITS_MIDDLE": "6", "NUM_GPUS": "8", "USE_QAT": "0", "QAT_BITS": "6", "TRAIN_SEQ_LEN": "2048", "SMEAR_GATE": "1", "BIGRAM_HASH": "1", "BIGRAM_VOCAB_SIZE": "2048", "BIGRAM_DIM": "128", "MUON_WEIGHT_DECAY": "0.04", "ADAM_WD": "0.04", "INIT_TYPE": "ortho", "USE_EMA": "1", "EMA_ALPHA": "0.997", "USE_NTK_ROPE": "1", "USE_FA3": "1", "FLASH_ATTN_BACKEND": "fa3", "FLASH_ATTN_STRICT": "1", "FLASH_ATTN_ARCH_LIST": "9.0", "QUANT_BITS_MLP": "6", "TIED_EMBED_LR": "0.035", "MUON_MOMENTUM_WARMUP_START": "0.92", "MUON_MOMENTUM_WARMUP_STEPS": "1500", "GRAD_CLIP_NORM": "0.3", "USE_TTT": "0", "XSA_LAST_N": "4", "ROPE_DIMS": "16", "LN_SCALE": "1", "LATE_QAT": "1", "QAT_THRESHOLD": "0.15", "TRAIN_BATCH_TOKENS": "786432", "EMA_FROM_INIT": "1", "GRAD_GUIDED_QUANT": "0", "BACKOUT_CONNECTION": "0", "SPECTRAL_WARMDOWN": "0", "VALUE_EMBED": "1", "VE_DIM": "128", "VE_LAYERS": "9,10", "TOKEN_SALIENCY": "0", "QUANT_ERROR_DIFFUSION": "0", "QUANT_HADAMARD": "0", "LAMBDA_VICREG": "0.0", "HYPERGRAPH_LIFT": "0", "HYPERGRAPH_SCALES": "2,4,8", "HEAD_TYPE": "standard"}, "fn_injection": null, "val_bpb": 2.43959467, "val_loss": 4.10053219, "peak_memory_mb": 26725, "training_time_ms": 604650, "num_params": 27092070, "num_steps": 4451, "artifact_size_bytes": 16729834, "artifact_est_bytes": 0, "status": "pending", "reward": -0.5, "description": "H1: Factorized head rank=64", "tags": []} +{"id": 3, "timestamp": "2026-03-26T18:54:01Z", "approach": "head_family", "variation": "factorized_r128", "subvariation": "seed42", "commit": 
"76ceb8e", "env_vars": {"NUM_LAYERS": "11", "MODEL_DIM": "512", "NUM_HEADS": "8", "NUM_KV_HEADS": "4", "MLP_MULT": "3", "MATRIX_LR": "0.025", "SCALAR_LR": "0.025", "MUON_MOMENTUM": "0.99", "SKIP_QUANT": "0", "MAX_WALLCLOCK_SECONDS": "600", "LAMBDA_SPECTRAL": "0.0", "LAMBDA_TUBE": "0.0", "RANK_DIM": "128", "ATTENTION_TYPE": "standard", "EVAL_STRIDE": "64", "WARMDOWN_ITERS": "3500", "COMPRESS_ALGO": "zstd", "QUANT_BITS_MIDDLE": "6", "NUM_GPUS": "8", "USE_QAT": "0", "QAT_BITS": "6", "TRAIN_SEQ_LEN": "2048", "SMEAR_GATE": "1", "BIGRAM_HASH": "1", "BIGRAM_VOCAB_SIZE": "2048", "BIGRAM_DIM": "128", "MUON_WEIGHT_DECAY": "0.04", "ADAM_WD": "0.04", "INIT_TYPE": "ortho", "USE_EMA": "1", "EMA_ALPHA": "0.997", "USE_NTK_ROPE": "1", "USE_FA3": "1", "FLASH_ATTN_BACKEND": "fa3", "FLASH_ATTN_STRICT": "1", "FLASH_ATTN_ARCH_LIST": "9.0", "QUANT_BITS_MLP": "6", "TIED_EMBED_LR": "0.035", "MUON_MOMENTUM_WARMUP_START": "0.92", "MUON_MOMENTUM_WARMUP_STEPS": "1500", "GRAD_CLIP_NORM": "0.3", "USE_TTT": "0", "XSA_LAST_N": "4", "ROPE_DIMS": "16", "LN_SCALE": "1", "LATE_QAT": "1", "QAT_THRESHOLD": "0.15", "TRAIN_BATCH_TOKENS": "786432", "EMA_FROM_INIT": "1", "GRAD_GUIDED_QUANT": "0", "BACKOUT_CONNECTION": "0", "SPECTRAL_WARMDOWN": "0", "VALUE_EMBED": "1", "VE_DIM": "128", "VE_LAYERS": "9,10", "TOKEN_SALIENCY": "0", "QUANT_ERROR_DIFFUSION": "0", "QUANT_HADAMARD": "0", "LAMBDA_VICREG": "0.0", "HYPERGRAPH_LIFT": "0", "HYPERGRAPH_SCALES": "2,4,8", "HEAD_TYPE": "standard"}, "fn_injection": null, "val_bpb": 1.92273838, "val_loss": 3.23178711, "peak_memory_mb": 26736, "training_time_ms": 599968, "num_params": 27190374, "num_steps": 4425, "artifact_size_bytes": 16918260, "artifact_est_bytes": 0, "status": "pending", "reward": -0.5, "description": "H2: Factorized head rank=128", "tags": []} +{"id": 4, "timestamp": "2026-03-26T19:06:33Z", "approach": "head_family", "variation": "mos_k2_r64", "subvariation": "seed42", "commit": "76ceb8e", "env_vars": {"NUM_LAYERS": "11", "MODEL_DIM": "512", "NUM_HEADS": "8", "NUM_KV_HEADS": "4", "MLP_MULT": "3", "MATRIX_LR": "0.025", "SCALAR_LR": "0.025", "MUON_MOMENTUM": "0.99", "SKIP_QUANT": "0", "MAX_WALLCLOCK_SECONDS": "600", "LAMBDA_SPECTRAL": "0.0", "LAMBDA_TUBE": "0.0", "RANK_DIM": "0", "ATTENTION_TYPE": "standard", "EVAL_STRIDE": "64", "WARMDOWN_ITERS": "3500", "COMPRESS_ALGO": "zstd", "QUANT_BITS_MIDDLE": "6", "NUM_GPUS": "8", "USE_QAT": "0", "QAT_BITS": "6", "TRAIN_SEQ_LEN": "2048", "SMEAR_GATE": "1", "BIGRAM_HASH": "1", "BIGRAM_VOCAB_SIZE": "2048", "BIGRAM_DIM": "128", "MUON_WEIGHT_DECAY": "0.04", "ADAM_WD": "0.04", "INIT_TYPE": "ortho", "USE_EMA": "1", "EMA_ALPHA": "0.997", "USE_NTK_ROPE": "1", "USE_FA3": "1", "FLASH_ATTN_BACKEND": "fa3", "FLASH_ATTN_STRICT": "1", "FLASH_ATTN_ARCH_LIST": "9.0", "QUANT_BITS_MLP": "6", "TIED_EMBED_LR": "0.035", "MUON_MOMENTUM_WARMUP_START": "0.92", "MUON_MOMENTUM_WARMUP_STEPS": "1500", "GRAD_CLIP_NORM": "0.3", "USE_TTT": "0", "XSA_LAST_N": "4", "ROPE_DIMS": "16", "LN_SCALE": "1", "LATE_QAT": "1", "QAT_THRESHOLD": "0.15", "TRAIN_BATCH_TOKENS": "786432", "EMA_FROM_INIT": "1", "GRAD_GUIDED_QUANT": "0", "BACKOUT_CONNECTION": "0", "SPECTRAL_WARMDOWN": "0", "VALUE_EMBED": "1", "VE_DIM": "128", "VE_LAYERS": "9,10", "TOKEN_SALIENCY": "0", "QUANT_ERROR_DIFFUSION": "0", "QUANT_HADAMARD": "0", "LAMBDA_VICREG": "0.0", "HYPERGRAPH_LIFT": "0", "HYPERGRAPH_SCALES": "2,4,8", "HEAD_TYPE": "mixture_softmax", "MIXTURE_SOFTMAX_K": "2", "MIXTURE_RANK_DIM": "64"}, "fn_injection": null, "val_bpb": 2.61671791, "val_loss": 4.39824539, "peak_memory_mb": 27889, 
"training_time_ms": 600007, "num_params": 27191398, "num_steps": 4428, "artifact_size_bytes": 16565348, "artifact_est_bytes": 0, "status": "pending", "reward": -0.5, "description": "H3: Mixture softmax K=2 rank=64", "tags": []} +{"id": 5, "timestamp": "2026-03-26T19:19:17Z", "approach": "head_family", "variation": "mos_k4_r64", "subvariation": "seed42", "commit": "76ceb8e", "env_vars": {"NUM_LAYERS": "11", "MODEL_DIM": "512", "NUM_HEADS": "8", "NUM_KV_HEADS": "4", "MLP_MULT": "3", "MATRIX_LR": "0.025", "SCALAR_LR": "0.025", "MUON_MOMENTUM": "0.99", "SKIP_QUANT": "0", "MAX_WALLCLOCK_SECONDS": "600", "LAMBDA_SPECTRAL": "0.0", "LAMBDA_TUBE": "0.0", "RANK_DIM": "0", "ATTENTION_TYPE": "standard", "EVAL_STRIDE": "64", "WARMDOWN_ITERS": "3500", "COMPRESS_ALGO": "zstd", "QUANT_BITS_MIDDLE": "6", "NUM_GPUS": "8", "USE_QAT": "0", "QAT_BITS": "6", "TRAIN_SEQ_LEN": "2048", "SMEAR_GATE": "1", "BIGRAM_HASH": "1", "BIGRAM_VOCAB_SIZE": "2048", "BIGRAM_DIM": "128", "MUON_WEIGHT_DECAY": "0.04", "ADAM_WD": "0.04", "INIT_TYPE": "ortho", "USE_EMA": "1", "EMA_ALPHA": "0.997", "USE_NTK_ROPE": "1", "USE_FA3": "1", "FLASH_ATTN_BACKEND": "fa3", "FLASH_ATTN_STRICT": "1", "FLASH_ATTN_ARCH_LIST": "9.0", "QUANT_BITS_MLP": "6", "TIED_EMBED_LR": "0.035", "MUON_MOMENTUM_WARMUP_START": "0.92", "MUON_MOMENTUM_WARMUP_STEPS": "1500", "GRAD_CLIP_NORM": "0.3", "USE_TTT": "0", "XSA_LAST_N": "4", "ROPE_DIMS": "16", "LN_SCALE": "1", "LATE_QAT": "1", "QAT_THRESHOLD": "0.15", "TRAIN_BATCH_TOKENS": "786432", "EMA_FROM_INIT": "1", "GRAD_GUIDED_QUANT": "0", "BACKOUT_CONNECTION": "0", "SPECTRAL_WARMDOWN": "0", "VALUE_EMBED": "1", "VE_DIM": "128", "VE_LAYERS": "9,10", "TOKEN_SALIENCY": "0", "QUANT_ERROR_DIFFUSION": "0", "QUANT_HADAMARD": "0", "LAMBDA_VICREG": "0.0", "HYPERGRAPH_LIFT": "0", "HYPERGRAPH_SCALES": "2,4,8", "HEAD_TYPE": "mixture_softmax", "MIXTURE_SOFTMAX_K": "4", "MIXTURE_RANK_DIM": "64"}, "fn_injection": null, "val_bpb": 2.7112436, "val_loss": 4.55712656, "peak_memory_mb": 28686, "training_time_ms": 599989, "num_params": 27389030, "num_steps": 4149, "artifact_size_bytes": 17172588, "artifact_est_bytes": 0, "status": "pending", "reward": -0.5, "description": "H4: Mixture softmax K=4 rank=64", "tags": []} +{"id": 6, "timestamp": "2026-03-26T19:31:47Z", "approach": "head_family", "variation": "mos_k4_r128", "subvariation": "seed42", "commit": "76ceb8e", "env_vars": {"NUM_LAYERS": "11", "MODEL_DIM": "512", "NUM_HEADS": "8", "NUM_KV_HEADS": "4", "MLP_MULT": "3", "MATRIX_LR": "0.025", "SCALAR_LR": "0.025", "MUON_MOMENTUM": "0.99", "SKIP_QUANT": "0", "MAX_WALLCLOCK_SECONDS": "600", "LAMBDA_SPECTRAL": "0.0", "LAMBDA_TUBE": "0.0", "RANK_DIM": "0", "ATTENTION_TYPE": "standard", "EVAL_STRIDE": "64", "WARMDOWN_ITERS": "3500", "COMPRESS_ALGO": "zstd", "QUANT_BITS_MIDDLE": "6", "NUM_GPUS": "8", "USE_QAT": "0", "QAT_BITS": "6", "TRAIN_SEQ_LEN": "2048", "SMEAR_GATE": "1", "BIGRAM_HASH": "1", "BIGRAM_VOCAB_SIZE": "2048", "BIGRAM_DIM": "128", "MUON_WEIGHT_DECAY": "0.04", "ADAM_WD": "0.04", "INIT_TYPE": "ortho", "USE_EMA": "1", "EMA_ALPHA": "0.997", "USE_NTK_ROPE": "1", "USE_FA3": "1", "FLASH_ATTN_BACKEND": "fa3", "FLASH_ATTN_STRICT": "1", "FLASH_ATTN_ARCH_LIST": "9.0", "QUANT_BITS_MLP": "6", "TIED_EMBED_LR": "0.035", "MUON_MOMENTUM_WARMUP_START": "0.92", "MUON_MOMENTUM_WARMUP_STEPS": "1500", "GRAD_CLIP_NORM": "0.3", "USE_TTT": "0", "XSA_LAST_N": "4", "ROPE_DIMS": "16", "LN_SCALE": "1", "LATE_QAT": "1", "QAT_THRESHOLD": "0.15", "TRAIN_BATCH_TOKENS": "786432", "EMA_FROM_INIT": "1", "GRAD_GUIDED_QUANT": "0", "BACKOUT_CONNECTION": "0", 
"SPECTRAL_WARMDOWN": "0", "VALUE_EMBED": "1", "VE_DIM": "128", "VE_LAYERS": "9,10", "TOKEN_SALIENCY": "0", "QUANT_ERROR_DIFFUSION": "0", "QUANT_HADAMARD": "0", "LAMBDA_VICREG": "0.0", "HYPERGRAPH_LIFT": "0", "HYPERGRAPH_SCALES": "2,4,8", "HEAD_TYPE": "mixture_softmax", "MIXTURE_SOFTMAX_K": "4", "MIXTURE_RANK_DIM": "128"}, "fn_injection": null, "val_bpb": 2.08984766, "val_loss": 3.51266861, "peak_memory_mb": 28738, "training_time_ms": 600000, "num_params": 27782246, "num_steps": 4160, "artifact_size_bytes": 17943057, "artifact_est_bytes": 0, "status": "pending", "reward": -0.5, "description": "H5: Mixture softmax K=4 rank=128", "tags": []} +{"id": 7, "timestamp": "2026-03-26T19:44:13Z", "approach": "head_family", "variation": "simplex_128", "subvariation": "seed42", "commit": "76ceb8e", "env_vars": {"NUM_LAYERS": "11", "MODEL_DIM": "512", "NUM_HEADS": "8", "NUM_KV_HEADS": "4", "MLP_MULT": "3", "MATRIX_LR": "0.025", "SCALAR_LR": "0.025", "MUON_MOMENTUM": "0.99", "SKIP_QUANT": "0", "MAX_WALLCLOCK_SECONDS": "600", "LAMBDA_SPECTRAL": "0.0", "LAMBDA_TUBE": "0.0", "RANK_DIM": "0", "ATTENTION_TYPE": "standard", "EVAL_STRIDE": "64", "WARMDOWN_ITERS": "3500", "COMPRESS_ALGO": "zstd", "QUANT_BITS_MIDDLE": "6", "NUM_GPUS": "8", "USE_QAT": "0", "QAT_BITS": "6", "TRAIN_SEQ_LEN": "2048", "SMEAR_GATE": "1", "BIGRAM_HASH": "1", "BIGRAM_VOCAB_SIZE": "2048", "BIGRAM_DIM": "128", "MUON_WEIGHT_DECAY": "0.04", "ADAM_WD": "0.04", "INIT_TYPE": "ortho", "USE_EMA": "1", "EMA_ALPHA": "0.997", "USE_NTK_ROPE": "1", "USE_FA3": "1", "FLASH_ATTN_BACKEND": "fa3", "FLASH_ATTN_STRICT": "1", "FLASH_ATTN_ARCH_LIST": "9.0", "QUANT_BITS_MLP": "6", "TIED_EMBED_LR": "0.035", "MUON_MOMENTUM_WARMUP_START": "0.92", "MUON_MOMENTUM_WARMUP_STEPS": "1500", "GRAD_CLIP_NORM": "0.3", "USE_TTT": "0", "XSA_LAST_N": "4", "ROPE_DIMS": "16", "LN_SCALE": "1", "LATE_QAT": "1", "QAT_THRESHOLD": "0.15", "TRAIN_BATCH_TOKENS": "786432", "EMA_FROM_INIT": "1", "GRAD_GUIDED_QUANT": "0", "BACKOUT_CONNECTION": "0", "SPECTRAL_WARMDOWN": "0", "VALUE_EMBED": "1", "VE_DIM": "128", "VE_LAYERS": "9,10", "TOKEN_SALIENCY": "0", "QUANT_ERROR_DIFFUSION": "0", "QUANT_HADAMARD": "0", "LAMBDA_VICREG": "0.0", "HYPERGRAPH_LIFT": "0", "HYPERGRAPH_SCALES": "2,4,8", "HEAD_TYPE": "simplex", "SIMPLEX_DIM": "128"}, "fn_injection": null, "val_bpb": 4.10691873, "val_loss": 6.90301248, "peak_memory_mb": 26760, "training_time_ms": 599973, "num_params": 27190374, "num_steps": 4241, "artifact_size_bytes": 10950817, "artifact_est_bytes": 0, "status": "pending", "reward": -0.5, "description": "H6: Simplex head dim=128", "tags": []} diff --git a/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/family_heads_review.md b/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/family_heads_review.md new file mode 100644 index 0000000000..c3087712c9 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/family_heads_review.md @@ -0,0 +1,26 @@ +# Higher-Rank Output Heads Family Study + +- Family: `higher_rank_heads` +- Source JSONL: `family_heads.jsonl` +- Runs: `7` +- Best result: `control_standard` at `1.1734 val_bpb` + +## Fixed Baseline + +11L/512d fixed backbone with EMA, XSA4, SmearGate, BigramHash, partial RoPE, LN Scale, VE128 on late layers, Late QAT, `seq2048`, Hopper FA3, compiled training, sliding evaluation, and the real quantization/artifact path. 
+ +## Results + +| ID | Variation | `val_bpb` | Steps | Time (s) | Artifact bytes | +|---:|---|---:|---:|---:|---:| +| 1 | control_standard | 1.1734 | 4415 | 600.1 | 16826913 | +| 2 | factorized_r64 | 2.4396 | 4451 | 604.6 | 16729834 | +| 3 | factorized_r128 | 1.9227 | 4425 | 600.0 | 16918260 | +| 4 | mos_k2_r64 | 2.6167 | 4428 | 600.0 | 16565348 | +| 5 | mos_k4_r64 | 2.7112 | 4149 | 600.0 | 17172588 | +| 6 | mos_k4_r128 | 2.0898 | 4160 | 600.0 | 17943057 | +| 7 | simplex_128 | 4.1069 | 4241 | 600.0 | 10950817 | + +## Main Finding + +The standard tied head outperformed every tested higher-rank alternative on this frontier-aligned 11L baseline. The simplex head reduced artifact size substantially but at an unusable quality cost. The mixture-softmax variants were both worse in score and, for the larger mixtures, larger in artifact size. diff --git a/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/install_flash_attn_hopper.sh b/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/install_flash_attn_hopper.sh new file mode 100644 index 0000000000..07ce2ddaa1 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/install_flash_attn_hopper.sh @@ -0,0 +1,96 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Build and cache a Hopper-only flash-attn wheel for true FA3 kernels. +# This is for H100-class pods. It avoids repeated generic multi-arch builds. + +if python3 - <<'PY' >/dev/null 2>&1 +mods = ("flash_attn_interface", "hopper.flash_attn_interface") +for mod in mods: + try: + __import__(mod, fromlist=["flash_attn_func"]) + raise SystemExit(0) + except Exception: + pass +raise SystemExit(1) +PY +then + echo "flash-attn Hopper kernels already importable" + exit 0 +fi + +FLASH_ATTN_REPO_URL="${FLASH_ATTN_REPO_URL:-https://github.com/Dao-AILab/flash-attention.git}" +FLASH_ATTN_REF="${FLASH_ATTN_REF:-main}" +FLASH_ATTN_BUILD_ROOT="${FLASH_ATTN_BUILD_ROOT:-/tmp/flash-attention-hopper}" +FLASH_ATTN_WHEEL_DIR="${FLASH_ATTN_WHEEL_DIR:-$HOME/.cache/flash-attn-hopper-wheels}" +FLASH_ATTN_ARCH_LIST="${FLASH_ATTN_ARCH_LIST:-9.0}" +MAX_JOBS="${MAX_JOBS:-8}" +NVCC_THREADS="${NVCC_THREADS:-4}" + +export MAX_JOBS +export NVCC_THREADS +export CMAKE_BUILD_PARALLEL_LEVEL="${CMAKE_BUILD_PARALLEL_LEVEL:-$MAX_JOBS}" +export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-$FLASH_ATTN_ARCH_LIST}" +export FLASH_ATTENTION_FORCE_BUILD="${FLASH_ATTENTION_FORCE_BUILD:-TRUE}" +export FLASH_ATTENTION_DISABLE_SM80="${FLASH_ATTENTION_DISABLE_SM80:-TRUE}" +export FLASH_ATTENTION_DISABLE_FP16="${FLASH_ATTENTION_DISABLE_FP16:-TRUE}" +export FLASH_ATTENTION_DISABLE_FP8="${FLASH_ATTENTION_DISABLE_FP8:-TRUE}" +export FLASH_ATTENTION_DISABLE_PAGEDKV="${FLASH_ATTENTION_DISABLE_PAGEDKV:-TRUE}" +export FLASH_ATTENTION_DISABLE_APPENDKV="${FLASH_ATTENTION_DISABLE_APPENDKV:-TRUE}" +export FLASH_ATTENTION_DISABLE_LOCAL="${FLASH_ATTENTION_DISABLE_LOCAL:-TRUE}" +export FLASH_ATTENTION_DISABLE_SOFTCAP="${FLASH_ATTENTION_DISABLE_SOFTCAP:-TRUE}" +export FLASH_ATTENTION_DISABLE_CLUSTER="${FLASH_ATTENTION_DISABLE_CLUSTER:-TRUE}" +export FLASH_ATTENTION_DISABLE_HDIM128="${FLASH_ATTENTION_DISABLE_HDIM128:-TRUE}" +export FLASH_ATTENTION_DISABLE_HDIM192="${FLASH_ATTENTION_DISABLE_HDIM192:-TRUE}" +export FLASH_ATTENTION_DISABLE_HDIM256="${FLASH_ATTENTION_DISABLE_HDIM256:-TRUE}" +export FLASH_ATTENTION_DISABLE_HDIMDIFF64="${FLASH_ATTENTION_DISABLE_HDIMDIFF64:-TRUE}" +export FLASH_ATTENTION_DISABLE_HDIMDIFF192="${FLASH_ATTENTION_DISABLE_HDIMDIFF192:-TRUE}" + +mkdir -p 
"$FLASH_ATTN_WHEEL_DIR" + +LATEST_WHEEL="$(ls -t "$FLASH_ATTN_WHEEL_DIR"/flash_attn_3-*.whl 2>/dev/null | head -1 || true)" +if [ -n "$LATEST_WHEEL" ]; then + python3 -m pip install --no-deps --force-reinstall "$LATEST_WHEEL" + if python3 - <<'PY' >/dev/null 2>&1 +mods = ("flash_attn_interface", "hopper.flash_attn_interface") +for mod in mods: + try: + __import__(mod, fromlist=["flash_attn_func"]) + raise SystemExit(0) + except Exception: + pass +raise SystemExit(1) +PY + then + echo "flash-attn Hopper kernels installed from cached wheel: $(basename "$LATEST_WHEEL")" + exit 0 + fi +fi + +if [ ! -d "$FLASH_ATTN_BUILD_ROOT/.git" ]; then + rm -rf "$FLASH_ATTN_BUILD_ROOT" + git clone --depth 1 --branch "$FLASH_ATTN_REF" "$FLASH_ATTN_REPO_URL" "$FLASH_ATTN_BUILD_ROOT" +else + git -C "$FLASH_ATTN_BUILD_ROOT" fetch --depth 1 origin "$FLASH_ATTN_REF" + git -C "$FLASH_ATTN_BUILD_ROOT" checkout -f FETCH_HEAD +fi + +cd "$FLASH_ATTN_BUILD_ROOT/hopper" +rm -rf build dist +python3 setup.py bdist_wheel + +WHEEL_PATH="$(ls -t dist/*.whl | head -1)" +cp "$WHEEL_PATH" "$FLASH_ATTN_WHEEL_DIR/" +python3 -m pip install --no-deps --force-reinstall "$FLASH_ATTN_WHEEL_DIR/$(basename "$WHEEL_PATH")" + +python3 - <<'PY' +mods = ("flash_attn_interface", "hopper.flash_attn_interface") +for mod in mods: + try: + __import__(mod, fromlist=["flash_attn_func"]) + print(f"flash-attn Hopper install OK via {mod}") + raise SystemExit(0) + except Exception: + pass +raise SystemExit("flash-attn Hopper install failed: FA3 module not importable") +PY diff --git a/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/run_higher_rank_heads_study.sh b/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/run_higher_rank_heads_study.sh new file mode 100755 index 0000000000..2bebccc881 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/run_higher_rank_heads_study.sh @@ -0,0 +1,67 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." 
&& pwd)" +cd "${REPO_ROOT}" +mkdir -p "${SCRIPT_DIR}/logs" + +NUM_GPUS=${NUM_GPUS:-8} +SEED=${SEED:-42} +OMP_NUM_THREADS=${OMP_NUM_THREADS:-1} +MAX_WALLCLOCK_SECONDS=${MAX_WALLCLOCK_SECONDS:-600} +EXPECTED_TRAIN_SHARDS=${EXPECTED_TRAIN_SHARDS:-80} + +actual_train_shards=$(find ./data/datasets/fineweb10B_sp1024 -name 'fineweb_train_*.bin' 2>/dev/null | wc -l | tr -d ' ') +if [ "${actual_train_shards}" -lt "${EXPECTED_TRAIN_SHARDS}" ]; then + echo "Preparing full FineWeb sp1024 dataset (${actual_train_shards}/${EXPECTED_TRAIN_SHARDS} train shards present)" + python3 data/cached_challenge_fineweb.py --variant sp1024 +fi +actual_train_shards=$(find ./data/datasets/fineweb10B_sp1024 -name 'fineweb_train_*.bin' 2>/dev/null | wc -l | tr -d ' ') +if [ "${actual_train_shards}" -lt "${EXPECTED_TRAIN_SHARDS}" ]; then + echo "Expected ${EXPECTED_TRAIN_SHARDS} train shards but found ${actual_train_shards}" >&2 + exit 1 +fi + +bash "${SCRIPT_DIR}/install_flash_attn_hopper.sh" + +base_env() { + env \ + PYTHONPATH="${REPO_ROOT}:${PYTHONPATH:-}" \ + OMP_NUM_THREADS="$OMP_NUM_THREADS" \ + SEED="$SEED" \ + NUM_LAYERS=11 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=3 \ + USE_EMA=1 EMA_FROM_INIT=1 EMA_ALPHA=0.997 XSA_LAST_N=4 \ + SMEAR_GATE=1 BIGRAM_HASH=1 BIGRAM_VOCAB_SIZE=2048 BIGRAM_DIM=128 \ + USE_NTK_ROPE=1 ROPE_DIMS=16 LN_SCALE=1 \ + VALUE_EMBED=1 VE_DIM=128 VE_LAYERS=9,10 \ + LATE_QAT=1 QAT_THRESHOLD=0.15 QAT_BITS=6 USE_QAT=0 \ + MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \ + MUON_MOMENTUM=0.99 MUON_WEIGHT_DECAY=0.04 ADAM_WD=0.04 \ + INIT_TYPE=ortho WARMDOWN_ITERS=3500 \ + TRAIN_BATCH_TOKENS=786432 TRAIN_SEQ_LEN=2048 \ + EVAL_STRIDE=64 VAL_LOSS_EVERY=0 TRAIN_LOG_EVERY=200 \ + COMPRESS_ALGO=zstd QUANT_BITS_MIDDLE=6 QUANT_BITS_MLP=6 GRAD_GUIDED_QUANT=0 \ + MAX_WALLCLOCK_SECONDS="$MAX_WALLCLOCK_SECONDS" \ + USE_FA3=1 FLASH_ATTN_BACKEND=fa3 FLASH_ATTN_STRICT=1 FLASH_ATTN_ARCH_LIST=9.0 \ + DATA_PATH=./data/datasets/fineweb10B_sp1024 TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model VOCAB_SIZE=1024 \ + SKIP_COMPILE=0 SKIP_QUANT=0 +} + +run_one() { + local slug="$1"; shift + local run_id="higher_rank_heads_${slug}_seed${SEED}" + base_env RUN_ID="$run_id" "$@" \ + torchrun --standalone --nproc_per_node="$NUM_GPUS" "${SCRIPT_DIR}/train_gpt.py" + if [ -f "logs/${run_id}.txt" ]; then + cp "logs/${run_id}.txt" "${SCRIPT_DIR}/logs/${slug}.log" + fi +} + +run_one H0_control_standard HEAD_TYPE=standard RANK_DIM=0 +run_one H1_factorized_r64 HEAD_TYPE=standard RANK_DIM=64 +run_one H2_factorized_r128 HEAD_TYPE=standard RANK_DIM=128 +run_one H3_mos_k2_r64 HEAD_TYPE=mixture_softmax MIXTURE_SOFTMAX_K=2 MIXTURE_RANK_DIM=64 RANK_DIM=0 +run_one H4_mos_k4_r64 HEAD_TYPE=mixture_softmax MIXTURE_SOFTMAX_K=4 MIXTURE_RANK_DIM=64 RANK_DIM=0 +run_one H5_mos_k4_r128 HEAD_TYPE=mixture_softmax MIXTURE_SOFTMAX_K=4 MIXTURE_RANK_DIM=128 RANK_DIM=0 +run_one H6_simplex_128 HEAD_TYPE=simplex SIMPLEX_DIM=128 RANK_DIM=0 diff --git a/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/train_gpt.py b/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/train_gpt.py new file mode 100644 index 0000000000..6b4a30a6f2 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_HigherRankHeads_11L_Study/train_gpt.py @@ -0,0 +1,2891 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. 
We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Conditionally import research modules for experimental primitives. +_ATTN = os.environ.get("ATTENTION_TYPE", "standard") +_MLP = os.environ.get("MLP_TYPE", "standard") +_RESEARCH = _ATTN != "standard" or _MLP != "standard" or os.environ.get("USE_TERNARY") == "1" or os.environ.get("OPTIMIZER") in ("conda", "qes", "matrix_free_cg") or os.environ.get("USE_QTT") == "1" or os.environ.get("USE_HASH_EMBEDDING") == "1" or os.environ.get("LIQUID_TYPE") == "symplectic" +if _RESEARCH: + try: + from research_modules import get_attention_cls, get_mlp_cls, TernaryLinear, CondaOptimizer, QuanticsLinear, QESOptimizer, MatrixFreeCGOptimizer, HashEmbedding, SymplecticLiquidLayer + except ImportError: + print("WARNING: research_modules.py not found, falling back to standard", file=__import__('sys').stderr) + _RESEARCH = False + +# Flash Attention backend selection: +# - auto: try Hopper FA3 first, then FA2, else SDPA +# - fa3: require true Hopper kernels unless FLASH_ATTN_STRICT=0 +# - fa2: require flash-attn v2 kernels unless FLASH_ATTN_STRICT=0 +# - sdpa: disable flash-attn direct kernels +_FLASH_ATTN_BACKEND = os.environ.get("FLASH_ATTN_BACKEND", "auto").lower() +_FLASH_ATTN_STRICT = os.environ.get("FLASH_ATTN_STRICT", "0") == "1" +_FA3_AVAILABLE = False +_FA_VERSION = "sdpa" +if os.environ.get("USE_FA3", "0") == "1" and _FLASH_ATTN_BACKEND != "sdpa": + _import_error = None + if _FLASH_ATTN_BACKEND in ("auto", "fa3"): + for _mod in ("flash_attn_interface", "hopper.flash_attn_interface"): + try: + _fa3_func = __import__(_mod, fromlist=["flash_attn_func"]).flash_attn_func + _FA3_AVAILABLE = True + _FA_VERSION = "fa3" + break + except Exception as e: + _import_error = e + if not _FA3_AVAILABLE and _FLASH_ATTN_BACKEND in ("auto", "fa2"): + try: + _fa3_func = __import__("flash_attn.flash_attn_interface", fromlist=["flash_attn_func"]).flash_attn_func + _FA3_AVAILABLE = True + _FA_VERSION = "fa2" + except Exception as e: + _import_error = e + if _FLASH_ATTN_BACKEND == "fa3" and _FA_VERSION != "fa3": + msg = "flash_attention: requested true fa3 but Hopper kernels were not importable" + if _FLASH_ATTN_STRICT: + raise RuntimeError(msg) from _import_error + print(f"WARNING: {msg}; falling back to sdpa", file=sys.stderr) + _FA3_AVAILABLE = False + _FA_VERSION = "sdpa" + elif _FLASH_ATTN_BACKEND == "fa2" and _FA_VERSION != "fa2": + msg = "flash_attention: requested fa2 but flash-attn v2 kernels were not importable" + if _FLASH_ATTN_STRICT: + raise RuntimeError(msg) from _import_error + print(f"WARNING: {msg}; falling back to sdpa", file=sys.stderr) + _FA3_AVAILABLE = False + _FA_VERSION = "sdpa" + print( + f"flash_attention: requested={_FLASH_ATTN_BACKEND} strict={int(_FLASH_ATTN_STRICT)} active={_FA_VERSION}", + file=sys.stderr, + ) + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# 
Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Tokenizer selection: sp1024 (default) or custom8192 (larger vocab, better compression). + _tokenizer_type = os.environ.get("TOKENIZER_TYPE", "sp1024") + if _tokenizer_type == "custom8192": + _data_default = "./data/datasets/fineweb10B_custom8192" + _tok_default = "./data/tokenizers/fineweb_8192_bpe.model" + _vocab_default = 8192 + else: + _data_default = "./data/datasets/fineweb10B_sp1024" + _tok_default = "./data/tokenizers/fineweb_1024_bpe.model" + _vocab_default = 1024 + + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", _data_default) + _train_on_val_early = os.environ.get("TRAIN_ON_VAL", "0") == "1" + train_files = os.path.join(data_path, "fineweb_val_*.bin" if _train_on_val_early else "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", _tok_default) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", str(_vocab_default))) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Research extensions: sparse attention and recurrent weight tying. + # WINDOW_SIZE=0 → dense attention (default). WINDOW_SIZE=128 → sliding window of 128. + window_size = int(os.environ.get("WINDOW_SIZE", 0)) + # NUM_PHYSICAL_LAYERS controls weight tying: if < num_layers, blocks are cycled. + # e.g. num_layers=9, num_physical_layers=3 → 3 unique blocks, each reused 3 times. + num_physical_layers = int(os.environ.get("NUM_PHYSICAL_LAYERS", 0)) # 0 = same as num_layers + + # Research extensions: regularizers (env-var gated, 0.0 = off). + lambda_spectral = float(os.environ.get("LAMBDA_SPECTRAL", 0.0)) # Spectral radius anisotropy penalty. + lambda_tube = float(os.environ.get("LAMBDA_TUBE", 0.0)) # Semantic tube ODE smoothness penalty. + lambda_coherence = float(os.environ.get("LAMBDA_COHERENCE", 0.0)) # Cross-layer weight coherence penalty. 
+ lambda_vicreg = float(os.environ.get("LAMBDA_VICREG", 0.0)) # VICReg variance-covariance regularizer. + # Sweep 006b: advanced quantization (post-training, zero training cost). + quant_error_diffusion = bool(int(os.environ.get("QUANT_ERROR_DIFFUSION", "0"))) # Delta-Sigma error diffusion during quant. + quant_hadamard = bool(int(os.environ.get("QUANT_HADAMARD", "0"))) # Hadamard rotation before quant (QuIP#-inspired). + # Sweep 006: novel techniques (env-var gated). + activation_type = os.environ.get("ACTIVATION_TYPE", "relu2") # relu2|chebyshev|star_relu + chebyshev_order = int(os.environ.get("CHEBYSHEV_ORDER", "4")) # Polynomial order for Chebyshev activation. + residual_mag_norm = bool(int(os.environ.get("RESIDUAL_MAG_NORM", "0"))) # Normalize residual RMS per block. + value_embed = bool(int(os.environ.get("VALUE_EMBED", "0"))) # Value Embeddings (VE128). + ve_dim = int(os.environ.get("VE_DIM", "128")) # VE embedding dimension. + ve_layers = os.environ.get("VE_LAYERS", "") # Comma-sep layer indices for VE, e.g. "7,9,10". + attention_pattern = os.environ.get("ATTENTION_PATTERN", "standard") # standard|hybrid_linear + softmax_anchor_layers = os.environ.get("SOFTMAX_ANCHOR_LAYERS", "") # Layers that keep full softmax in hybrid. + linear_chunk_size = int(os.environ.get("LINEAR_CHUNK_SIZE", "64")) # Chunk size for chunked linear attention. + token_saliency = bool(int(os.environ.get("TOKEN_SALIENCY", "0"))) # Per-token soft routing gate. + saliency_skip_layers = os.environ.get("SALIENCY_SKIP_LAYERS", "") # Layers where saliency applies. + saliency_chunk_sizes = os.environ.get("SALIENCY_CHUNK_SIZES", "") # Hierarchical chunk saliency scales, e.g. "4,16,64". + hypergraph_lift = bool(int(os.environ.get("HYPERGRAPH_LIFT", "0"))) # Poset-inspired multi-scale hypergraph lift. + hypergraph_layers = os.environ.get("HYPERGRAPH_LAYERS", "") # Layers where hypergraph lift applies. + hypergraph_scales = os.environ.get("HYPERGRAPH_SCALES", "2,4,8") # Chunk scales for graded-poset lifting. + # Research extensions: hybrid equivalence attention (0 = off). + num_anchor_tokens = int(os.environ.get("NUM_ANCHOR_TOKENS", 0)) # Global anchor tokens for hybrid attn. + # Research extensions: adaptive liquid layers (0 = off). + liquid_depth = int(os.environ.get("LIQUID_DEPTH", 0)) # Total logical depth for depth-routed layers. + router_hidden = int(os.environ.get("ROUTER_HIDDEN", 0)) # DepthRouter hidden dim (0 = model_dim//4). + # Research extensions: factorized output head (0 = standard head). + rank_dim = int(os.environ.get("RANK_DIM", 0)) # Low-rank head: d_model → rank → vocab. + head_type = os.environ.get("HEAD_TYPE", "standard") # standard|mixture_softmax|simplex + mixture_softmax_k = int(os.environ.get("MIXTURE_SOFTMAX_K", "4")) # Number of expert heads in mixture softmax. + mixture_rank_dim = int(os.environ.get("MIXTURE_RANK_DIM", "128")) # Bottleneck dim per expert for mixture head. + simplex_dim = int(os.environ.get("SIMPLEX_DIM", "256")) # Simplex bottleneck size for simplex head. + # Research extensions: alternative attention/MLP/optimizer/quantization. 
+ attention_type = os.environ.get("ATTENTION_TYPE", "standard") # standard|delta|fox|log_linear|allmem|performer|mla + mlp_type = os.environ.get("MLP_TYPE", "standard") # standard|kan|hyper + use_ternary = bool(int(os.environ.get("USE_TERNARY", "0"))) # 1.58-bit ternary weights + optimizer_type = os.environ.get("OPTIMIZER", "muon") # muon|conda|adam|qes|matrix_free_cg + use_qtt = bool(int(os.environ.get("USE_QTT", "0"))) # Quantics Tensor Train embedding + use_hash_embedding = bool(int(os.environ.get("USE_HASH_EMBEDDING", "0"))) # HashEmbedding (zero-param) + tt_rank = int(os.environ.get("TT_RANK", 8)) # QTT rank + num_cores = int(os.environ.get("NUM_CORES", 12)) # QTT core count + # Symplectic liquid layer (Hamiltonian recurrence). + liquid_type = os.environ.get("LIQUID_TYPE", "standard") # standard|symplectic + # Initialization strategy. + init_type = os.environ.get("INIT_TYPE", "default") # default|overtone + # Competitive hacks (from leaderboard analysis). + train_on_val = bool(int(os.environ.get("TRAIN_ON_VAL", "0"))) # Train on val split (allowed per rules). + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) # Sliding window eval stride (0 = standard). + data_curriculum = os.environ.get("DATA_CURRICULUM", "") # "length_sort" = train easy-to-hard by shard size. + loss_type = os.environ.get("LOSS_TYPE", "cross_entropy") # "jepa_latent" for JEPA L2 latent loss. + latent_dim = int(os.environ.get("LATENT_DIM", "256")) # Latent projection dimension for JEPA. + patch_size = int(os.environ.get("PATCH_SIZE", "1")) # SP-4096: compress N tokens → 1 vector (4 = 4x context). + log_tube_metrics = bool(int(os.environ.get("LOG_TUBE_METRICS", "0"))) # Log STP-style drift/curvature/isotropy metrics. + + # Competition leader techniques. + compress_algo = os.environ.get("COMPRESS_ALGO", "zlib") # zlib|zstd (zstd-22 saves ~1-2MB) + quant_bits_middle = int(os.environ.get("QUANT_BITS_MIDDLE", "8")) # 6 for int6 middle layers + quant_bits_mlp = int(os.environ.get("QUANT_BITS_MLP", "0")) # 0 = same as quant_bits_middle; 5 for int5 MLP (PR#180) + prune_fraction = float(os.environ.get("PRUNE_FRACTION", "0.0")) # Magnitude pruning: zero smallest N% of 2D weights (PR#205: 0.02) + optimizer_variant = os.environ.get("OPTIMIZER_VARIANT", "standard") # standard|normuon + lora_ttt = bool(int(os.environ.get("LORA_TTT", "0"))) # LoRA test-time training during eval + lora_rank = int(os.environ.get("LORA_RANK", "4")) + ttt_lr = float(os.environ.get("TTT_LR", "0.01")) + ttt_steps = int(os.environ.get("TTT_STEPS", "3")) + # Full-weight TTT (PR#254 SOTA: 1.1303 bpb): SGD on validation data before quantization. + # Adapts ALL model weights (except frozen early blocks) to the val distribution. + # Applied after EMA/pruning, before quantization — the adapted weights get compressed. 
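    # A hypothetical example invocation (the values echo the PR#254 defaults documented below):
    #   USE_TTT=1 TTT_SGD_LR=0.002 TTT_SGD_MOMENTUM=0.9 TTT_EPOCHS=3 TTT_FREEZE_BLOCKS=2 \
    #     torchrun --standalone --nproc_per_node=8 train_gpt.py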
+ use_ttt = bool(int(os.environ.get("USE_TTT", "0"))) + ttt_sgd_lr = float(os.environ.get("TTT_SGD_LR", "0.002")) # PR#254: 0.002, PR#267: 0.004 + ttt_sgd_momentum = float(os.environ.get("TTT_SGD_MOMENTUM", "0.9")) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", "3")) # Passes over val data + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", "2")) # Freeze first N blocks (PR#254: 2) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", "1.0")) # Gradient clip norm + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", "32")) # Batch size in sequences + tokenizer_type = os.environ.get("TOKENIZER_TYPE", "sp1024") # sp1024|custom8192 + use_loop_embed = bool(int(os.environ.get("USE_LOOP_EMBED", "0"))) # Per-loop depth embedding for recurrence + use_fa3 = bool(int(os.environ.get("USE_FA3", "0"))) # Flash Attention 3 (PR#198 SOTA) + use_qat = bool(int(os.environ.get("USE_QAT", "0"))) # Quantization-aware training (STE) + qat_bits = int(os.environ.get("QAT_BITS", "8")) # QAT bit width (8 or 6) + qat_start_frac = float(os.environ.get("QAT_START_FRAC", "0.0")) # Delay QAT until this fraction of training (0.25 = 25%) + use_bigram_gate = bool(int(os.environ.get("USE_BIGRAM_GATE", "0"))) # Legacy parameter-free bigram (deprecated) + bigram_hash_bins = int(os.environ.get("BIGRAM_HASH_BINS", "4096")) + # SOTA techniques from competition leaders (PR#198, #194, #180): + smear_gate = bool(int(os.environ.get("SMEAR_GATE", "0"))) # Per-dim learned gate blending prev+current token + bigram_hash = bool(int(os.environ.get("BIGRAM_HASH", "0"))) # Learned bigram embedding with per-layer lambdas + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", "2048")) # Hash buckets for bigram embedding + bigram_dim = int(os.environ.get("BIGRAM_DIM", "128")) # Bigram embedding dim (projected to model_dim if != model_dim) + use_swa = bool(int(os.environ.get("USE_SWA", "0"))) # Stochastic Weight Averaging during warmdown + swa_start_ratio = float(os.environ.get("SWA_START_RATIO", "0.5")) # Start SWA when lr_scale < this + swa_every = int(os.environ.get("SWA_EVERY", "50")) # Collect SWA checkpoint every N steps + # EMA (Exponential Moving Average) — superior to SWA for short training runs (PR#201, #223). + # When use_ema=1, SWA is ignored. EMA biases toward recent weights, avoiding dilution from early checkpoints. + use_ema = bool(int(os.environ.get("USE_EMA", "0"))) + ema_alpha = float(os.environ.get("EMA_ALPHA", "0.999")) # Decay factor: higher = slower adaptation + ema_start_ratio = float(os.environ.get("EMA_START_RATIO", "0.6")) # Start EMA when lr_scale < this + ema_from_init = bool(int(os.environ.get("EMA_FROM_INIT", "0"))) # PR#315: init EMA at step 0, update every step + # NTK-RoPE: scale RoPE base frequency for longer sequences (PR#198, #206). + # Auto-enabled when USE_NTK_ROPE=1 and train_seq_len > 1024. + use_ntk_rope = bool(int(os.environ.get("USE_NTK_ROPE", "0"))) + # XSA (Exclusive Self-Attention): remove self-value bias from attention outputs (PR#265, #315). + # Applied to last N layers only. GQA-aware implementation avoids repeat_interleave overhead. + xsa_last_n = int(os.environ.get("XSA_LAST_N", "0")) # 0 = off, 3-4 = PR#265/#315 + # Partial RoPE: only apply rotary embeddings to first N dims of each head (PR#315). + # 0 = full (all head_dim dims rotated). 16 = only first 16 of 64 dims. Rest position-free. + rope_dims = int(os.environ.get("ROPE_DIMS", "0")) # 0 = all dims + # LN Scale: scale RMSNorm output by 1/sqrt(layer_idx+1) to stabilize deep layers (PR#315). 
+ ln_scale = bool(int(os.environ.get("LN_SCALE", "0"))) + # Late QAT: enable STE fake-quantization only in last ~4% of training (PR#315). + # Better than full QAT: no training quality loss, but model learns quantization-friendly weights. + late_qat = bool(int(os.environ.get("LATE_QAT", "0"))) + qat_threshold = float(os.environ.get("QAT_THRESHOLD", "0.1")) # Enable when lr_scale < this + # FP16 embeddings: keep tied embeddings in FP16 during quantization (PR#180). + # Critical when using int5 MLP — prevents logit noise from compounding with MLP noise. + fp16_embeddings = bool(int(os.environ.get("FP16_EMBEDDINGS", "0"))) + # Gradient-guided adaptive quantization (PR#332): assign int5/int6/int7 per tensor + # based on accumulated gradient magnitude during warmdown. High-gradient = more bits. + grad_guided_quant = bool(int(os.environ.get("GRAD_GUIDED_QUANT", "0"))) + # Backout connection (PR#339): subtract scaled mid-layer hidden state from output. + # Denoises the residual stream by removing processing artifacts. + backout_connection = bool(int(os.environ.get("BACKOUT_CONNECTION", "0"))) + backout_layer = int(os.environ.get("BACKOUT_LAYER", "5")) # Which layer's output to subtract + # Spectral warmdown scheduling (NOVEL): per-layer LR scaling by spectral norm. + # Layers with higher spectral norm (less converged) get proportionally more LR during warmdown. + spectral_warmdown = bool(int(os.environ.get("SPECTRAL_WARMDOWN", "0"))) + # Context curriculum: start training at shorter seq_len, ramp to full (PR#203). + # Yields ~25% more gradient steps by doing early training at seq512. + # Format: "512:0.4,1024:0.7,2048:1.0" means seq512 for first 40%, seq1024 until 70%, seq2048 after. + seq_curriculum = os.environ.get("SEQ_CURRICULUM", "") # e.g. "512:0.4,1024:0.7,2048:1.0" + + # Dev/mini-run flags. + skip_quant = bool(int(os.environ.get("SKIP_QUANT", "0"))) # Skip int8+zlib for fast local runs. + dev_mode = bool(int(os.environ.get("DEV_MODE", "0"))) # Allow MPS/CPU (no CUDA required). + skip_compile = bool(int(os.environ.get("SKIP_COMPILE", "0"))) # Skip torch.compile (Triton OOM workaround). + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.0)) # Decoupled WD on Muon (SOTA: 0.038-0.04) + adam_weight_decay = float(os.environ.get("ADAM_WD", 0.0)) # AdamW WD for embeds/scalars (SOTA: 0.01-0.04) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, variant: str = "standard", weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + self.variant = variant + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + if self.variant == "normuon": + g = g / g.norm(dim=0, keepdim=True).clamp(min=1e-8) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
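                    # (Newton-Schulz returns a near semi-orthogonal matrix, whose entry RMS is
                    # about 1/sqrt(max(rows, cols)); scaling tall matrices by sqrt(rows/cols)
                    # restores a uniform 1/sqrt(cols) update scale across parameter shapes.)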
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
                curr += p.numel()

            if distributed:
                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            weight_decay = group.get("weight_decay", 0.0)
            curr = 0
            for p in params:
                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
                if weight_decay > 0:
                    p.mul_(1 - lr * weight_decay)
                p.add_(g, alpha=-lr)
                curr += p.numel()

        return loss


# -----------------------------
# TOKENIZER-AGNOSTIC EVALUATION SETUP
# -----------------------------
#
# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer.
# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.

def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
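    # Concatenate every val shard, then keep a whole number of seq_len windows plus
    # one trailing token so the shifted (input, target) pairs stay aligned downstream.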
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +class LoRAAdapter(nn.Module): + """Low-rank adapter for test-time training during eval.""" + def __init__(self, linear: nn.Module, rank: int = 4, alpha: float = 8.0): + super().__init__() + self.linear = linear + self.scale = alpha / rank + din = linear.in_features + dout = linear.out_features + self.A = nn.Parameter(torch.zeros(din, rank, device=linear.weight.device, dtype=torch.bfloat16)) + self.B = nn.Parameter(torch.zeros(rank, dout, device=linear.weight.device, dtype=torch.bfloat16)) + nn.init.normal_(self.A, std=0.02) + + def forward(self, x: Tensor) -> Tensor: + return self.linear(x) + (x.to(self.A.dtype) @ self.A @ self.B) * self.scale + + +def _apply_lora_ttt(model: nn.Module, val_tokens: Tensor, device: torch.device, args: Hyperparameters) -> list: + """Wrap Q/V projections with LoRA adapters, run TTT gradient steps, return originals for cleanup.""" + base = model.module if hasattr(model, "module") else model + blocks = base.blocks if hasattr(base, "blocks") else [] + originals: list[tuple] = [] + adapters: list[LoRAAdapter] = [] + + for block in blocks: + attn = block.attn if hasattr(block, "attn") else None + if attn is None: + continue + for proj_name in ("c_q", "c_v"): + if hasattr(attn, proj_name): + orig = getattr(attn, proj_name) + adapter = LoRAAdapter(orig, rank=args.lora_rank).to(device) + setattr(attn, proj_name, adapter) + originals.append((attn, proj_name, orig)) + adapters.append(adapter) + + if not adapters: + return [] + + # Collect LoRA parameters and create optimizer + lora_params = [] + for a in adapters: + lora_params.extend([a.A, a.B]) + opt = torch.optim.Adam(lora_params, lr=args.ttt_lr) + + # TTT: take gradient steps on a sample of val tokens + seq_len = args.train_seq_len + sample_size = min(seq_len * 4, val_tokens.numel() - 1) + sample = val_tokens[:sample_size + 1].to(device=device, dtype=torch.int64) + x = sample[:-1].unsqueeze(0) + y = sample[1:].unsqueeze(0) + + model.train() + for _ in range(args.ttt_steps): + opt.zero_grad() + with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + loss.backward() + opt.step() + model.eval() + + return originals + + +def _remove_lora(originals: list[tuple]) -> None: + """Restore original modules after TTT.""" + for parent, name, orig in originals: + setattr(parent, name, orig) + + +def apply_full_weight_ttt( + base_model: nn.Module, + val_tokens: Tensor, + device: torch.device, + args: Hyperparameters, + rank: int = 0, + world_size: int = 1, + log_fn=None, +) -> None: + """Full-weight Test-Time Training: adapt model weights on validation data via SGD. + + Based on PR#254 (SOTA 1.1303 bpb): run SGD on the val split for N epochs, + freezing early blocks for stability. The adapted weights are then quantized. + This gives the model a chance to tune itself to the specific val distribution. 
+ """ + if log_fn is None: + log_fn = lambda *a, **kw: None + + # Freeze early blocks (PR#254: first 2 blocks) + frozen_params = set() + blocks = base_model.blocks if hasattr(base_model, "blocks") else [] + for i, block in enumerate(blocks): + if i < args.ttt_freeze_blocks: + for p in block.parameters(): + p.requires_grad_(False) + frozen_params.add(id(p)) + + # Collect trainable parameters + ttt_params = [p for p in base_model.parameters() if p.requires_grad and id(p) not in frozen_params] + total_params = sum(p.numel() for p in ttt_params) + log_fn(f"ttt:starting lr={args.ttt_sgd_lr} epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks} trainable_params={total_params}") + + # Create SGD optimizer (PR#254: lr=0.002, momentum=0.9) + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_sgd_lr, momentum=args.ttt_sgd_momentum) + + # Distribute val data across ranks (like DDP). Each rank processes 1/N of the + # data, then we all_reduce gradients at each step for synchronized training. + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() + total_seqs = (total_tokens - 1) // seq_len + distributed = dist.is_available() and dist.is_initialized() and world_size > 1 + # Each rank gets its slice of sequences + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + log_fn(f"ttt:rank={rank} seqs={seq_start}-{seq_end} ({seq_end - seq_start} of {total_seqs})") + + base_model.train() + t0_ttt = time.perf_counter() + for epoch in range(args.ttt_epochs): + epoch_loss = 0.0 + epoch_batches = 0 + for batch_start in range(seq_start, seq_end, args.ttt_batch_seqs): + batch_end = min(batch_start + args.ttt_batch_seqs, seq_end) + raw_start = batch_start * seq_len + raw_end = batch_end * seq_len + 1 + if raw_end > total_tokens: + break + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True): + loss = base_model(x, y) + loss.backward() + # Sync gradients across ranks before optimizer step (like DDP) + if distributed: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.SUM) + p.grad /= world_size + if args.ttt_grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + epoch_loss += loss.item() + epoch_batches += 1 + + avg_loss = epoch_loss / max(epoch_batches, 1) + log_fn(f"ttt:epoch={epoch+1}/{args.ttt_epochs} avg_loss={avg_loss:.4f}") + + ttt_secs = time.perf_counter() - t0_ttt + + # Unfreeze all blocks + for block in blocks: + for p in block.parameters(): + p.requires_grad_(True) + + base_model.eval() + log_fn(f"ttt:complete epochs={args.ttt_epochs} time={ttt_secs:.1f}s") + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got 
VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + _accum_dtype = torch.float32 if device.type == "mps" else torch.float64 + val_loss_sum = torch.zeros((), device=device, dtype=_accum_dtype) + val_token_count = torch.zeros((), device=device, dtype=_accum_dtype) + val_byte_count = torch.zeros((), device=device, dtype=_accum_dtype) + + # LoRA TTT: adapt Q/V projections on val tokens before scoring + _lora_originals = _apply_lora_ttt(model, val_tokens, device, args) if args.lora_ttt else [] + + model.eval() + _stride = args.eval_stride + with torch.inference_mode(): + if _stride > 0: + # Strided sliding-window evaluation: each token scored with max context. + # Only count loss on the last `stride` tokens of each window. + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() + _max_slide_iters = int(os.environ.get("MAX_SLIDE_ITERS", "4096")) # Cap sliding window iters + _slide_count = 0 + for pos in range(seq_start * seq_len, min(seq_end * seq_len, total_tokens - seq_len - 1), _stride): + if _slide_count >= _max_slide_iters: + break + _slide_count += 1 + end = min(pos + seq_len + 1, total_tokens) + local = val_tokens[pos:end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].unsqueeze(0) # (1, seq_len) + y = local[1:].unsqueeze(0) + with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True): + # Get per-token loss by calling forward with reduction='none' if supported, + # otherwise use the standard mean loss as approximation + batch_loss = model(x, y).detach() + # Count only the last `stride` tokens (they have full context) + score_count = min(_stride, x.size(1)) + val_loss_sum += batch_loss.to(_accum_dtype) * score_count + val_token_count += score_count + # Byte counting for the scored tokens + scored_prev = x[0, -score_count:].reshape(-1) + scored_tgt = y[0, -score_count:].reshape(-1) + tb = base_bytes_lut[scored_tgt].to(dtype=torch.int16) + tb += (has_leading_space_lut[scored_tgt] & ~is_boundary_token_lut[scored_prev]).to(dtype=torch.int16) + val_byte_count += tb.to(_accum_dtype).sum() + else: + # Standard non-overlapping evaluation + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(_accum_dtype) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(_accum_dtype).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + 
dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + # Remove LoRA adapters before returning + if _lora_originals: + _remove_lora(_lora_originals) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small quality hit) by quantizing it to int8 and zlib-compressing the result. +# We can then decompress the model and run it in higher precision for evaluation, having come in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor, bits: int = 8, error_diffusion: bool = False) -> tuple[Tensor, Tensor]: + max_val = (1 << (bits - 1)) - 1 # 127 for int8, 31 for int6, 15 for int5 (SOTA: PR#180 uses full range) + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(max_val)).clamp_min(1.0 / float(max_val)) + if error_diffusion: + # Delta-Sigma error diffusion: accumulate rounding error across columns. + # Preserves aggregate row statistics by propagating quantization residual. + # Ref: QES (2602.03120) ΔΣ concept, adapted for post-training weight quant.
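+ # Worked example of the diffusion loop below (hypothetical row, max_val=127):
+ # normalized = [70.6, -28.4, 12.9]: round(70.6)=71 leaves residual -0.4;
+ # the next column sees -28.4 - 0.4 = -28.8 and rounds to -29, leaving +0.2;
+ # the next sees 12.9 + 0.2 = 13.1 and rounds to 13. Plain rounding would give
+ # [71, -28, 13]; diffusion trades per-element error for a smaller running row error.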
+ normalized = clipped / scale[:, None] + q = torch.zeros_like(normalized) + error = torch.zeros(t32.shape[0], device=t32.device, dtype=torch.float32) + for j in range(t32.shape[1]): + val = normalized[:, j] + error + rounded = torch.clamp(torch.round(val), -max_val, max_val) + error = val - rounded + q[:, j] = rounded + return q.to(torch.int8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + q = torch.clamp(torch.round(clipped / scale[:, None]), -max_val, max_val).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale (no error diffusion needed). + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / float(max_val) if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -max_val, max_val).to(torch.int8).contiguous() + return q, scale + + +def _random_orthogonal_seeded(n: int, seed: int = 42) -> Tensor: + """Generate deterministic random orthogonal matrix via QR decomposition. + Using a fixed seed means the rotation can be reconstructed at dequant time + without storing the matrix (zero storage overhead).""" + gen = torch.Generator().manual_seed(seed) + A = torch.randn(n, n, generator=gen) + Q, R = torch.linalg.qr(A) + Q = Q * torch.sign(torch.diag(R)).unsqueeze(0) # proper orthogonal + return Q + +def _is_middle_layer(name: str, max_block_idx: int) -> bool: + """Returns True for block layers that are neither the first nor the last.""" + import re as _re + m = _re.match(r"blocks\.(\d+)\.", name) + if not m or max_block_idx <= 1: + return False + idx = int(m.group(1)) + return 0 < idx < max_block_idx + +def _is_mlp_layer(name: str) -> bool: + """Returns True for MLP weight tensors (used for int5 mixed quantization, PR#180).""" + return ".mlp." in name + +def quantize_state_dict_int8(state_dict: dict[str, Tensor], quant_bits_middle: int = 8, quant_bits_mlp: int = 0, fp16_embeddings: bool = False, grad_bit_map: dict[str, int] | None = None, error_diffusion: bool = False, hadamard: bool = False): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + # Precompute max block index for middle-layer detection + import re as _re + _block_indices = [int(_re.match(r"blocks\.(\d+)\.", k).group(1)) for k in state_dict if _re.match(r"blocks\.(\d+)\.", k)] + _max_block_idx = max(_block_indices) if _block_indices else 0 + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # FP16 embeddings: keep tok_emb/lm_head in FP16 to prevent logit noise (PR#180). 
+ # Critical when using int5 MLP — avoids compounding quantization noise. + if fp16_embeddings and ("tok_emb" in name or "lm_head" in name): + kept = t.to(dtype=torch.float16).contiguous() + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + if grad_bit_map is not None and name in grad_bit_map: + bits = grad_bit_map[name] # gradient-guided per-tensor bit assignment + elif _is_middle_layer(name, _max_block_idx): + if quant_bits_mlp > 0 and _is_mlp_layer(name): + bits = quant_bits_mlp # int5 for MLP (PR#180) + elif quant_bits_middle < 8: + bits = quant_bits_middle # int6 for attention + other middle layers + else: + bits = 8 + else: + bits = 8 + # Hadamard rotation: rotate columns before quantizing to spread outliers (QuIP#-inspired). + # Uses seeded RNG so rotation can be reconstructed at dequant (zero storage overhead). + _rotated = False + if hadamard and t.ndim == 2 and t.shape[1] >= 64: + _Q = _random_orthogonal_seeded(t.shape[1], seed=42 + hash(name) % 10000) + t = t @ _Q.to(device=t.device, dtype=t.dtype) + _rotated = True + q, s = quantize_float_tensor(t, bits=bits, error_diffusion=error_diffusion) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + if _rotated: + qmeta.setdefault(name, {})["hadamard"] = True + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def compute_grad_bit_map(grad_sensitivity: dict[str, float], quant_bits_middle: int = 6, target_avg_bits: float = 6.0) -> dict[str, int]: + """Gradient-guided adaptive quantization (PR#332 inspired). + Rank tensors by accumulated gradient magnitude. Assign fewer bits to low-sensitivity + tensors (int5) and more bits to high-sensitivity tensors (int7). 
Target avg ~6 bits.""" + if not grad_sensitivity: + return {} + names = sorted(grad_sensitivity.keys()) + sensitivities = [grad_sensitivity[n] for n in names] + total = sum(sensitivities) + if total <= 0: + return {} + # Rank by sensitivity: bottom 30% get int5, middle 40% get int6, top 30% get int7 + indexed = sorted(enumerate(sensitivities), key=lambda x: x[1]) + n = len(indexed) + bit_map: dict[str, int] = {} + for rank_pos, (orig_idx, _) in enumerate(indexed): + frac = rank_pos / max(n - 1, 1) + if frac < 0.3: + bit_map[names[orig_idx]] = max(quant_bits_middle - 1, 4) # int5 + elif frac < 0.7: + bit_map[names[orig_idx]] = quant_bits_middle # int6 + else: + bit_map[names[orig_idx]] = min(quant_bits_middle + 1, 8) # int7 + return bit_map + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + t_deq = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))) + else: + scale = float(s.item()) + t_deq = q.float() * scale + # Hadamard rotation inverse: reconstruct the same rotation and apply transpose. + if qmeta.get(name, {}).get("hadamard"): + _Q = _random_orthogonal_seeded(t_deq.shape[1], seed=42 + hash(name) % 10000) + t_deq = t_deq @ _Q.T + out[name] = t_deq.to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def prune_state_dict(state_dict: dict[str, Tensor], fraction: float) -> dict[str, Tensor]: + """Zero out the smallest `fraction` of weights by magnitude in 2D float tensors (PR#205: 2%).""" + if fraction <= 0: + return state_dict + pruned = {} + total_pruned = total_eligible = 0 + for name, t in state_dict.items(): + if t.ndim == 2 and t.is_floating_point(): + threshold = torch.quantile(t.abs().flatten().float(), fraction) + mask = t.abs() >= threshold + pruned[name] = t * mask + total_pruned += int((~mask).sum().item()) + total_eligible += t.numel() + else: + pruned[name] = t + return pruned + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + # Assumes the standard modded-nanogpt shard layout: 256 little-endian int32 header words, + # token count in header[2], then uint16 token ids. + header_bytes = 256 * np.dtype("<i4").itemsize + with file.open("rb") as f: + header = np.frombuffer(f.read(header_bytes), dtype="<i4") + num_tokens = int(header[2]) + tokens = np.frombuffer(f.read(), dtype="<u2", count=num_tokens) + return torch.from_numpy(tokens.astype(np.int32)) + + +class TokenStream: + # Serves a contiguous token stream over the sorted shard files; only the attributes + # consumed by _advance_file/take below are assumed, and the curriculum string is + # stored for callers rather than interpreted here. + def __init__(self, pattern: str, curriculum: str = ""): + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + if not self.files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + self.curriculum = curriculum + self.file_idx = 0 + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def _advance_file(self) -> None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
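+ # Worked example for next_batch below (illustrative values: global_tokens=786432,
+ # world_size=8, grad_accum_steps=1): local_tokens=98304 and per_rank_span=98305,
+ # so one take() call consumes 8 * 98305 = 786440 tokens and rank r reads the
+ # disjoint half-open span [r * 98305, (r + 1) * 98305).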
+ # Note: sequence packing (USE_PACKING) is the default behavior — the loader reads a + # contiguous token stream with no padding. USE_PACKING=1 is a no-op confirmation. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device, + curriculum: str = ""): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern, curriculum=curriculum) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Pre-allocates cos/sin tables at init for max_seq_len. + # CRITICAL: buffers are registered (stable object IDs) so torch._dynamo + # never sees a guard failure when seq_len changes mid-training (curriculum). + def __init__(self, dim: int, base: float = 10000.0, ntk_seq_len: int = 0, max_seq_len: int = 2048, rope_dims: int = 0): + super().__init__() + # Partial RoPE (PR#315): if rope_dims > 0, only compute rotary for that many dims. + self.rope_dims = rope_dims if rope_dims > 0 else dim + actual_dim = self.rope_dims # Rotary table size matches the dims we actually rotate + # NTK-RoPE: scale base frequency when training at longer seq_len (PR#198, #206). + # base_ntk = base * (seq_len / 1024) ^ (dim / (dim - 2)) + if ntk_seq_len > 1024: + alpha = ntk_seq_len / 1024.0 + base = base * (alpha ** (actual_dim / max(actual_dim - 2, 1))) + inv_freq = 1.0 / (base ** (torch.arange(0, actual_dim, 2, dtype=torch.float32) / actual_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # Pre-compute for max_seq_len — slice in forward, never recreate. + t = torch.arange(max_seq_len, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self.register_buffer("_cos_cached", freqs.cos()[None, None, :, :], persistent=False) + self.register_buffer("_sin_cached", freqs.sin()[None, None, :, :], persistent=False) + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + # Slice pre-computed buffers — no tensor recreation, no dynamo guard failures. 
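+ # Worked example of the NTK scaling computed in __init__ (illustrative numbers):
+ # with rope_dims=16 and ntk_seq_len=2048, alpha = 2048 / 1024 = 2 and the base is
+ # multiplied by 2 ** (16 / 14) ≈ 2.21, i.e. 10000 grows to roughly 22100, which
+ # stretches the lowest frequencies so positions past 1024 remain distinguishable.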
+ return ( + self._cos_cached[:, :, :seq_len, :].to(device=device, dtype=dtype), + self._sin_cached[:, :, :seq_len, :].to(device=device, dtype=dtype), + ) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + """Apply rotary embeddings. If rope_dims > 0, only rotate first rope_dims dims (Partial RoPE, PR#315).""" + if rope_dims > 0 and rope_dims < x.size(-1): + # Partial RoPE: only rotate first rope_dims dims, keep rest unchanged + x_rot, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rot[..., :half], x_rot[..., half:] + x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rot, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + window_size: int = 0, + use_fa3: bool = False, + ntk_seq_len: int = 0, + max_seq_len: int = 2048, + use_xsa: bool = False, + rope_dims: int = 0, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + self.use_fa3 = use_fa3 and _FA3_AVAILABLE + self.use_xsa = use_xsa + self.rope_dims = rope_dims + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, ntk_seq_len=ntk_seq_len, max_seq_len=max_seq_len, rope_dims=rope_dims) + self.window_size = window_size + self.num_anchor_tokens = 0 # set externally for hybrid equiv attention + + def forward(self, x: Tensor, ve: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + # Value Embeddings: add token-identity signal to values before attention. + if ve is not None: + v = v + ve.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + # Chunked linear attention fast path (for hybrid layers). 
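+ # The chunked linear path below assumes matching query/KV head counts, so k and v
+ # are first expanded from (B, Hkv, S, D) to (B, H, S, D); e.g. with 8 query heads
+ # over 4 KV heads, n_rep = 2 and each KV head serves two query heads.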
+ if getattr(self, 'use_linear_attn', False): + if self.num_kv_heads != self.num_heads: + n_rep = self.num_heads // self.num_kv_heads + k = k.unsqueeze(2).expand(-1, -1, n_rep, -1, -1).reshape(bsz, self.num_heads, seqlen, self.head_dim) + v = v.unsqueeze(2).expand(-1, -1, n_rep, -1, -1).reshape(bsz, self.num_heads, seqlen, self.head_dim) + y = self._chunked_linear_attn(q, k, v) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + # flash-attn direct-kernel fast path. In practice this is either: + # - FA3 Hopper kernels via flash_attn_interface / hopper.flash_attn_interface + # - FA2 kernels via flash_attn.flash_attn_interface + # Layout: (B, S, H, D). + if self.use_fa3 and self.num_anchor_tokens == 0: + q_fa = q.transpose(1, 2).to(torch.bfloat16) + k_fa = k.transpose(1, 2).to(torch.bfloat16) + v_fa = v.transpose(1, 2).to(torch.bfloat16) + fa_kw = dict(causal=True) + if self.window_size > 0: + fa_kw["window_size"] = (self.window_size, 0) + y = _fa3_func(q_fa, k_fa, v_fa, **fa_kw) + # XSA: project out self-value bias from attention output (PR#265, #315). + # GQA-aware: reshape to groups to avoid expensive repeat_interleave. + if self.use_xsa: + y_pre = y # (B, T, H, D) + Hkv = self.num_kv_heads + group_size = self.num_heads // Hkv + v_for_xsa = v_fa # already (B, T, Hkv, D) — FA format + vn = F.normalize(v_for_xsa, dim=-1).unsqueeze(-2) # (B, T, Hkv, 1, D) — bf16 matches PR#315 + y_grouped = y_pre.reshape(bsz, seqlen, Hkv, group_size, self.head_dim) + dot = (y_grouped * vn).sum(-1, keepdim=True) + y = (y_grouped - dot * vn).reshape(bsz, seqlen, dim) + else: + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + # Build attention mask. Dense causal (window_size=0, no anchors) uses fast is_causal=True. + # Sliding-window and/or hybrid-equivalence anchors require an explicit mask. + K = self.num_anchor_tokens + _use_gqa = self.num_kv_heads != self.num_heads + if self.window_size > 0 or K > 0: + idx = torch.arange(seqlen, device=x.device) + row, col = idx.unsqueeze(1), idx.unsqueeze(0) + future = col > row # base causal constraint + if self.window_size > 0 and K > 0: + outside_window = (row - col) >= self.window_size + not_anchor_col = col >= K + blocked = future | ((row >= K) & not_anchor_col & outside_window) + elif self.window_size > 0: + blocked = future | (row - col >= self.window_size) + else: + blocked = future + attn_mask = torch.zeros(seqlen, seqlen, device=x.device, dtype=q.dtype) + attn_mask = attn_mask.masked_fill(blocked, float("-inf")) + # Explicit mask requires manual GQA expansion (SDPA enable_gqa only works with is_causal) + if _use_gqa: + n_rep = self.num_heads // self.num_kv_heads + k = k.unsqueeze(2).expand(-1, -1, n_rep, -1, -1).reshape(bsz, self.num_heads, seqlen, self.head_dim) + v = v.unsqueeze(2).expand(-1, -1, n_rep, -1, -1).reshape(bsz, self.num_heads, seqlen, self.head_dim) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False) + else: + # Dense causal: use native GQA (PyTorch 2.5+, avoids expensive k/v head expansion) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True, enable_gqa=_use_gqa) + # XSA for SDPA path: same GQA-aware projection (PR#265). 
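+ # XSA below removes, per token and KV group, the component of the attention output
+ # lying along the token's own normalized value vector: y' = y - (y · v_hat) v_hat.
+ # That is an orthogonal projection (applying it twice changes nothing), and every
+ # query head within a KV group is projected against the same v_hat.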
+ if self.use_xsa: + y_t = y.transpose(1, 2) # (B, T, H, D) + Hkv = self.num_kv_heads + group_size = self.num_heads // Hkv + # v was already expanded above; get original kv-head v from c_v + v_orig = self.c_v(x).reshape(bsz, seqlen, Hkv, self.head_dim) # (B, T, Hkv, D) + vn = F.normalize(v_orig, dim=-1).unsqueeze(-2) # (B, T, Hkv, 1, D) — bf16 matches PR#315 + y_grouped = y_t.reshape(bsz, seqlen, Hkv, group_size, self.head_dim) + dot = (y_grouped * vn).sum(-1, keepdim=True) + y = (y_grouped - dot * vn).reshape(bsz, seqlen, dim) + else: + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class ChebyshevMLP(nn.Module): + """MLP with learnable Chebyshev polynomial activation modulation. + Subsumes relu2, Star-ReLU as special cases. The optimizer discovers the optimal + nonlinearity for this model/data scale via order+1 learnable coefficients.""" + def __init__(self, dim: int, mlp_mult: int, order: int = 4): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.cheb_coeffs = nn.Parameter(torch.zeros(order + 1)) + self.cheb_coeffs.data[0] = 1.0 # init near identity modulation + + def forward(self, x: Tensor) -> Tensor: + h = self.fc(x) + t = torch.tanh(h) + T0, T1 = torch.ones_like(t), t + mod = self.cheb_coeffs[0] * T0 + self.cheb_coeffs[1] * T1 + for i in range(2, len(self.cheb_coeffs)): + T2 = 2 * t * T1 - T0 + mod = mod + self.cheb_coeffs[i] * T2 + T0, T1 = T1, T2 + return self.proj(torch.relu(h).square() * torch.sigmoid(mod)) + + +class StarReluMLP(nn.Module): + """Star-ReLU: relu^2 + learned per-dim affine (PR#505, MetaFormer). 
+ The scale/bias let the optimizer control magnitude range and prevent dead neurons.""" + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.scale = nn.Parameter(torch.ones(hidden, dtype=torch.float32)) + self.bias = nn.Parameter(torch.zeros(hidden, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + h = torch.relu(self.fc(x)).square() + return self.proj(h * self.scale.to(dtype=h.dtype) + self.bias.to(dtype=h.dtype)) + + +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific deep layers (PR#505, #414).""" + def __init__(self, vocab_size: int, ve_dim: int, kv_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, kv_dim, bias=False) if ve_dim != kv_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class ChunkedCausalLinearAttention(nn.Module): + """Causal linear attention via chunked parallel computation. + + Splits sequence into chunks of size C. Within each chunk: O(C²) intra-chunk + attention in parallel across all chunks. Between chunks: additive KV state + propagated via exclusive prefix sum. All ops are batched matmuls + cumsum — + fully parallelizable, torch.compile(fullgraph=True) compatible. + + Complexity: O(L·C + L·d²/C) ≈ O(L·d) when C≈d. + For L=2048, C=64, d=64: ~8.5M ops/head vs ~268M for full softmax (31x cheaper). + """ + def __init__(self, chunk_size: int = 64): + super().__init__() + self.chunk_size = chunk_size + # Causal mask is registered as a buffer and resized if needed. + self.register_buffer("_causal_mask", torch.tril(torch.ones(chunk_size, chunk_size)), persistent=False) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + B, H, L, D = q.shape + V = v.shape[-1] + C = self.chunk_size + # Feature maps: ELU+1 ensures non-negativity for valid attention weights. + phi_q = F.elu(q, alpha=1.0) + 1 + phi_k = F.elu(k, alpha=1.0) + 1 + # Pad to exact multiple of chunk_size. + pad = (C - L % C) % C + if pad > 0: + phi_q = F.pad(phi_q, (0, 0, 0, pad)) + phi_k = F.pad(phi_k, (0, 0, 0, pad)) + v = F.pad(v, (0, 0, 0, pad)) + L_pad = phi_q.shape[2] + nc = L_pad // C + # Reshape into chunks: (B, H, nc, C, D). + phi_q = phi_q.reshape(B, H, nc, C, D) + phi_k = phi_k.reshape(B, H, nc, C, D) + v_c = v.reshape(B, H, nc, C, V) + # INTRA-CHUNK: quadratic attention within each chunk (parallel over all chunks). + intra_scores = torch.einsum('bhcid,bhcjd->bhcij', phi_q, phi_k) + causal_mask = self._causal_mask.to(dtype=q.dtype, device=q.device) + if causal_mask.shape[0] != C: + causal_mask = torch.tril(torch.ones(C, C, device=q.device, dtype=q.dtype)) + intra_scores = intra_scores * causal_mask[None, None, None] + intra_out = torch.einsum('bhcij,bhcjv->bhciv', intra_scores, v_c) + intra_norm = intra_scores.sum(dim=-1, keepdim=True) + # INTER-CHUNK: propagate accumulated KV state across chunks via exclusive prefix sum. 
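+ # Exclusive prefix sum, illustrated with three per-chunk states s1, s2, s3:
+ # cumsum([s1, s2, s3]) - [s1, s2, s3] = [0, s1, s1 + s2], so chunk c only sees KV
+ # state accumulated strictly before it: causality across chunks with no Python loop.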
+ chunk_kv = torch.einsum('bhcid,bhciv->bhcdv', phi_k, v_c) # (B, H, nc, D, V) + chunk_k_sum = phi_k.sum(dim=3) # (B, H, nc, D) + kv_state = chunk_kv.cumsum(dim=2) - chunk_kv # exclusive: state BEFORE each chunk + k_state = chunk_k_sum.cumsum(dim=2) - chunk_k_sum + inter_out = torch.einsum('bhcid,bhcdv->bhciv', phi_q, kv_state) + inter_norm = torch.einsum('bhcid,bhcd->bhci', phi_q, k_state).unsqueeze(-1) + # COMBINE: sum intra + inter, normalize. + out = (intra_out + inter_out) / (intra_norm + inter_norm).clamp(min=1e-6) + out = out.reshape(B, H, L_pad, V) + if pad > 0: + out = out[:, :, :L, :] + return out + + +# --- Research extension: regularizer functions (gated by lambda=0 → no-op) --- + +def spectral_radius_penalty(h: Tensor, lambda_reg: float) -> Tensor: + """Penalize hidden-state anisotropy (narrow cone). Ref: Manny Ko, Jenison 1996.""" + if lambda_reg <= 0: + return torch.zeros((), device=h.device) + h_flat = F.normalize(h.reshape(-1, h.size(-1)), dim=-1) + if h_flat.size(0) > 256: # sample to bound compute + h_flat = h_flat[torch.randperm(h_flat.size(0), device=h.device)[:256]] + sim = h_flat @ h_flat.T + eye = torch.eye(sim.size(0), device=sim.device, dtype=sim.dtype) + return lambda_reg * ((sim - eye) ** 2).mean() + + +def tube_regularizer(h: Tensor, lambda_tube: float) -> Tensor: + """Penalize trajectory acceleration (semantic tube). Ref: Choi et al. arXiv:2602.22617.""" + if lambda_tube <= 0 or h.size(1) < 3: + return torch.zeros((), device=h.device) + vel = h[:, 1:] - h[:, :-1] + acc = vel[:, 1:] - vel[:, :-1] + return lambda_tube * (acc ** 2).mean() + + +def vicreg_regularizer(h: Tensor, lambda_vic: float) -> Tensor: + """VICReg-style variance-invariance-covariance regularizer (Bardes et al. 2022). + Forces hidden states toward isotropic Gaussian N(0,I). More principled than spectral penalty: + - Variance term: force per-dim std toward 1.0 (prevents collapse to a point) + - Covariance term: force off-diagonal covariance toward 0 (prevents narrow cone) + Adapted from LeWorldModel SIGReg for LM hidden states.""" + if lambda_vic <= 0: + return torch.zeros((), device=h.device) + h_flat = h.reshape(-1, h.size(-1)) # (B*T, D) + if h_flat.size(0) > 512: # sample to bound compute + h_flat = h_flat[torch.randperm(h_flat.size(0), device=h.device)[:512]] + # Variance: force std dev of each dimension toward 1.0 + std = h_flat.std(dim=0) + var_loss = F.relu(1.0 - std).mean() + # Covariance: force off-diagonal covariance toward 0 (decorrelation) + h_centered = h_flat - h_flat.mean(dim=0, keepdim=True) + N = h_centered.size(0) + cov = (h_centered.T @ h_centered) / max(N - 1, 1) # (D, D) + D = cov.size(0) + cov_loss = (cov.square().sum() - cov.diagonal().square().sum()) / max(D * (D - 1), 1) + return lambda_vic * (var_loss + cov_loss) + + +def _trajectory_metrics_impl(h: Tensor) -> dict[str, float]: + """STP-style diagnostics for hidden-state trajectories. 
+ + Reports: + - drift_cos: cosine alignment of successive hidden states + - curvature: second-order finite-difference energy + - isotropy: mean eigenvalue / max eigenvalue of sample covariance + """ + with torch.no_grad(): + h = h.float().cpu() + if h.size(1) < 2: + return {"drift_cos": 0.0, "curvature": 0.0, "isotropy": 0.0} + if h.size(0) > 4: + h = h[:4] + if h.size(1) > 256: + h = h[:, :256] + h1 = F.normalize(h[:, :-1], dim=-1) + h2 = F.normalize(h[:, 1:], dim=-1) + drift = (h1 * h2).sum(dim=-1).mean() + if h.size(1) >= 3: + curv = ((h[:, 2:] - 2 * h[:, 1:-1] + h[:, :-2]) ** 2).mean() + else: + curv = torch.zeros((), device=h.device) + h_flat = h.reshape(-1, h.size(-1)) + if h_flat.size(0) > 512: + h_flat = h_flat[torch.randperm(h_flat.size(0), device=h.device)[:512]] + h_flat = h_flat - h_flat.mean(dim=0, keepdim=True) + cov = (h_flat.T @ h_flat) / max(h_flat.size(0) - 1, 1) + eigvals = torch.linalg.eigvalsh(cov).clamp_min(0) + isotropy = (eigvals.mean() / eigvals.max().clamp_min(1e-8)) + return { + "drift_cos": float(drift.item()), + "curvature": float(curv.item()), + "isotropy": float(isotropy.item()), + } + + +try: + import torch._dynamo as _torch_dynamo + trajectory_metrics = _torch_dynamo.disable(_trajectory_metrics_impl) +except Exception: + trajectory_metrics = _trajectory_metrics_impl + + +# --- Research extension: depth-routed weight sharing (Liquid Foundation Models) --- + +class DepthRouter(nn.Module): + """Tiny MLP that modulates hidden states based on logical depth fraction.""" + def __init__(self, model_dim: int, hidden: int = 0): + super().__init__() + hidden = hidden or model_dim // 4 + self.net = nn.Sequential( + nn.Linear(1, hidden, bias=False), nn.SiLU(), + nn.Linear(hidden, model_dim, bias=False)) + + def forward(self, x: Tensor, depth_frac: float) -> Tensor: + d = torch.tensor([depth_frac], device=x.device, dtype=x.dtype) + return x * torch.sigmoid(self.net(d))[None, None, :] + + +# --- Research extension: factorized output head (Yang et al. 1711.03953) --- + +class FactorizedHead(nn.Module): + """Low-rank projection: d_model → rank_dim → vocab_size.""" + def __init__(self, model_dim: int, vocab_size: int, rank_dim: int): + super().__init__() + self.down = CastedLinear(model_dim, rank_dim, bias=False) + self.up = CastedLinear(rank_dim, vocab_size, bias=False) + self.up._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + return self.up(self.down(x)) + + +class MixtureSoftmaxHead(nn.Module): + """Higher-rank output head via a gated mixture of low-rank experts. + + This targets the softmax rank bottleneck without paying for K full vocab + projections from the model dimension. 
+ """ + + def __init__(self, model_dim: int, vocab_size: int, num_experts: int, rank_dim: int): + super().__init__() + self.num_experts = num_experts + self.gate = CastedLinear(model_dim, num_experts, bias=False) + self.down = nn.ModuleList([CastedLinear(model_dim, rank_dim, bias=False) for _ in range(num_experts)]) + self.up = nn.ModuleList([CastedLinear(rank_dim, vocab_size, bias=False) for _ in range(num_experts)]) + for layer in self.up: + layer._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + gate = F.softmax(self.gate(x), dim=-1) # (N, K) + logits = None + for i in range(self.num_experts): + expert_logits = self.up[i](self.down[i](x)) + weight = gate[:, i:i+1] + logits = expert_logits * weight if logits is None else logits + expert_logits * weight + return logits + + +class SimplexHead(nn.Module): + """Project hidden states onto a learned probability simplex before logits.""" + + def __init__(self, model_dim: int, vocab_size: int, simplex_dim: int): + super().__init__() + self.to_simplex = CastedLinear(model_dim, simplex_dim, bias=False) + self.from_simplex = CastedLinear(simplex_dim, vocab_size, bias=False) + self.from_simplex._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + simplex = F.softmax(self.to_simplex(x), dim=-1) + return self.from_simplex(simplex) + + +class HypergraphLift(nn.Module): + """Graded-poset-inspired multi-scale hypergraph lifting over token chunks. + + We treat interval chunks at multiple scales as hyperedges, build overlap-aware + node summaries, then broadcast them back to tokens as a structural residual. + """ + + def __init__(self, model_dim: int, scales: list[int]): + super().__init__() + self.scales = [s for s in scales if s > 1] + self.node_proj = nn.ModuleDict({str(s): CastedLinear(model_dim, model_dim, bias=False) for s in self.scales}) + self.mix_proj = CastedLinear(model_dim, model_dim, bias=False) + self.mix_proj._zero_init = True + self.gates = nn.Parameter(torch.zeros(len(self.scales), dtype=torch.float32)) + + def _chunk_mean(self, x: Tensor, chunk: int) -> tuple[Tensor, int]: + B, T, D = x.shape + pad = (chunk - (T % chunk)) % chunk + if pad: + x = F.pad(x, (0, 0, 0, pad)) + Tpad = x.size(1) + nodes = x.view(B, Tpad // chunk, chunk, D).mean(dim=2) + return nodes, pad + + def forward(self, x: Tensor) -> Tensor: + if not self.scales: + return torch.zeros_like(x) + B, T, D = x.shape + outputs = [] + for idx, chunk in enumerate(self.scales): + nodes, pad = self._chunk_mean(x, chunk) # (B, N, D) + # Overlap geometry: left/right intersections become extra hypernodes. 
+ left = F.pad(nodes[:, :-1], (0, 0, 1, 0)) + right = F.pad(nodes[:, 1:], (0, 0, 0, 1)) + overlaps = 0.5 * (left + right) + parent = F.avg_pool1d(nodes.transpose(1, 2), kernel_size=2, stride=1, padding=1).transpose(1, 2)[:, :nodes.size(1)] + poset_nodes = self.node_proj[str(chunk)](nodes + overlaps + parent) + broadcast = poset_nodes[:, :, None, :].expand(-1, -1, chunk, -1).reshape(B, -1, D) + if pad: + broadcast = broadcast[:, :T] + outputs.append(torch.sigmoid(self.gates[idx]).to(dtype=x.dtype) * broadcast) + lifted = torch.stack(outputs, dim=0).sum(dim=0) / math.sqrt(len(outputs)) + return self.mix_proj(lifted) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + window_size: int = 0, + attention_type: str = "standard", + mlp_type: str = "standard", + activation_type: str = "relu2", + chebyshev_order: int = 4, + use_fa3: bool = False, + ntk_seq_len: int = 0, + max_seq_len: int = 2048, + use_xsa: bool = False, + rope_dims: int = 0, + ln_scale_factor: float = 1.0, + residual_mag_norm: bool = False, + token_saliency: bool = False, + saliency_chunk_sizes: str = "", + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.ln_scale_factor = ln_scale_factor # LN Scale: 1/sqrt(layer_idx+1) (PR#315) + if _RESEARCH and attention_type != "standard": + attn_cls = get_attention_cls(attention_type) + self.attn = attn_cls(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, window_size) + else: + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, window_size, use_fa3=use_fa3, ntk_seq_len=ntk_seq_len, max_seq_len=max_seq_len, use_xsa=use_xsa, rope_dims=rope_dims) + # MLP dispatch: Chebyshev > Star-ReLU > research modules > standard relu² + if activation_type == "chebyshev": + self.mlp = ChebyshevMLP(dim, mlp_mult, order=chebyshev_order) + elif activation_type == "star_relu": + self.mlp = StarReluMLP(dim, mlp_mult) + elif _RESEARCH and mlp_type != "standard": + mlp_cls = get_mlp_cls(mlp_type) + self.mlp = mlp_cls(dim, mlp_mult) + else: + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + # Residual magnitude normalization: keep residual RMS constant across layers. + if residual_mag_norm: + self.target_rms = nn.Parameter(torch.ones(1, dtype=torch.float32)) + # Token saliency routing: soft gate for adaptive computation per token. + if token_saliency: + self.saliency_gate = nn.Linear(dim, 1, bias=True) + nn.init.constant_(self.saliency_gate.bias, 3.0) # sigmoid(3)≈0.95: most tokens pass initially + self.saliency_chunk_sizes = [int(s) for s in saliency_chunk_sizes.split(",") if s.strip()] + if self.saliency_chunk_sizes: + self.saliency_chunk_gates = nn.ModuleList(nn.Linear(dim, 1, bias=False) for _ in self.saliency_chunk_sizes) + + def forward(self, x: Tensor, x0: Tensor, ve: Tensor | None = None) -> Tensor: + # Token saliency: save input before computation for soft-blend skip. 
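+ # The soft-blend skip applied at the end of this method computes
+ # out = x_in + g * (block(x_in) - x_in) with g = sigmoid(saliency logits) in (0, 1);
+ # the bias init of 3.0 gives g ≈ 0.95, so training starts near full compute and the
+ # gate can only gradually learn to bypass the block for individual tokens.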
+ _saliency_gate = None + if hasattr(self, 'saliency_gate'): + saliency_logits = self.saliency_gate(F.rms_norm(x, (x.size(-1),))) + if hasattr(self, 'saliency_chunk_gates'): + B, T, D = x.shape + for chunk, proj in zip(self.saliency_chunk_sizes, self.saliency_chunk_gates): + if chunk <= 1: + continue + pad = (chunk - (T % chunk)) % chunk + x_pad = F.pad(x, (0, 0, 0, pad)) if pad else x + pooled = x_pad.view(B, x_pad.size(1) // chunk, chunk, D).mean(dim=2) + chunk_logits = proj(F.rms_norm(pooled, (pooled.size(-1),))) + chunk_logits = chunk_logits[:, :, None, :].expand(-1, -1, chunk, -1).reshape(B, -1, 1) + saliency_logits = saliency_logits + chunk_logits[:, :T] + _saliency_gate = torch.sigmoid(saliency_logits) # (B, T, 1) + x_pre_saliency = x + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + # LN Scale (PR#315): damp deeper layer norms by 1/sqrt(layer_idx+1). + s = self.ln_scale_factor + attn_out = self.attn(self.attn_norm(x) * s if s != 1.0 else self.attn_norm(x), ve=ve) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + mlp_in = self.mlp_norm(x) * s if s != 1.0 else self.mlp_norm(x) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(mlp_in) + # Residual magnitude normalization: normalize RMS to learnable target. + if hasattr(self, 'target_rms'): + rms = x.float().square().mean(dim=-1, keepdim=True).sqrt().clamp(min=1e-8) + x = x * (self.target_rms.to(dtype=x.dtype) / rms) + # Token saliency: soft-blend skip (gate=1 → full compute, gate→0 → skip). + if _saliency_gate is not None: + x = x_pre_saliency + _saliency_gate * (x - x_pre_saliency) + return x + + +class QATLinear(nn.Module): + """Quantization-Aware Training wrapper with Straight-Through Estimator. + Simulates the quantization grid during forward pass so the model learns + to work with quantized weights natively. 
STE passes gradients through rounding.""" + def __init__(self, linear: nn.Module, bits: int = 8): + super().__init__() + self.linear = linear + self.bits = bits + self.max_val = (1 << (bits - 1)) - 1 # 127 for 8-bit, 31 for 6-bit + self.enabled = True # Can be disabled for delayed QAT (QAT_START_FRAC) + + @property + def weight(self): + return self.linear.weight + + @property + def in_features(self): + return self.linear.in_features + + @property + def out_features(self): + return self.linear.out_features + + def forward(self, x: Tensor) -> Tensor: + if not self.enabled: + return self.linear(x) + w = self.linear.weight + # Per-row scale (matches post-training quantization scheme, PR#315 scale floor) + scale = (w.abs().amax(dim=-1, keepdim=True) / self.max_val).clamp_min(1.0 / self.max_val) + # Quantize-then-dequantize (fake quantization) — asymmetric range [-32,31] for int6 (PR#315) + w_q = (w / scale).round().clamp(-self.max_val - 1, self.max_val) * scale + # STE: forward uses quantized weights, backward treats rounding as identity + w_ste = w + (w_q - w).detach() + return F.linear(x, w_ste, getattr(self.linear, 'bias', None)) + + +def get_bigram_hash(x: Tensor, bigram_vocab_size: int) -> Tensor: + """Hash consecutive token pairs into bigram bucket indices (SOTA: PR#198, #194).""" + rand_int_1, rand_int_2 = 36313, 27191 + mod = bigram_vocab_size - 1 + x32 = x.to(torch.int32) + out = x32.clone() + out[..., 0] = mod # position 0 → sentinel bucket + out[..., 1:] = torch.bitwise_xor(rand_int_1 * x32[..., 1:], rand_int_2 * x32[..., :-1]) % mod + return out.long() + + +class BigramHashGate(nn.Module): + """Legacy parameter-free bigram gate (deprecated, kept for backward compat).""" + def __init__(self, d_model: int, hash_bins: int = 4096): + super().__init__() + self.hash_bins = hash_bins + self.prime = 31 + position = torch.arange(hash_bins).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + pe = torch.zeros(hash_bins, d_model) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + self.register_buffer('hash_table', pe) + + def forward(self, hidden: Tensor, input_ids: Tensor) -> Tensor: + padded = F.pad(input_ids, (1, 0), value=0) + hashes = (padded[:, 1:] * self.prime + padded[:, :-1]) % self.hash_bins + return hidden * self.hash_table[hashes.long()] + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_physical_layers: int = 0, + window_size: int = 0, + num_anchor_tokens: int = 0, + liquid_depth: int = 0, + router_hidden: int = 0, + rank_dim: int = 0, + lambda_spectral: float = 0.0, + lambda_tube: float = 0.0, + lambda_vicreg: float = 0.0, + attention_type: str = "standard", + mlp_type: str = "standard", + use_qtt: bool = False, + tt_rank: int = 8, + num_cores: int = 12, + use_hash_embedding: bool = False, + liquid_type: str = "standard", + init_type: str = "standard", + loss_type: str = "cross_entropy", + latent_dim: int = 256, + patch_size: int = 1, + use_loop_embed: bool = False, + use_bigram_gate: bool = False, + bigram_hash_bins: int = 4096, + smear_gate: bool = False, + bigram_hash: bool = False, + bigram_vocab_size: int = 2048, + bigram_dim: int = 128, + use_fa3: bool = False, + ntk_seq_len: int = 0, + max_seq_len: int = 2048, + xsa_last_n: int = 0, + rope_dims: 
int = 0, + ln_scale: bool = False, + backout_connection: bool = False, + backout_layer: int = 5, + # Sweep 006 novel techniques. + activation_type: str = "relu2", + chebyshev_order: int = 4, + residual_mag_norm: bool = False, + token_saliency: bool = False, + saliency_chunk_sizes: str = "", + value_embed: bool = False, + ve_dim: int = 128, + ve_layers: str = "", + attention_pattern: str = "standard", + softmax_anchor_layers: str = "", + linear_chunk_size: int = 64, + hypergraph_lift: bool = False, + hypergraph_layers: str = "", + hypergraph_scales: str = "2,4,8", + head_type: str = "standard", + mixture_softmax_k: int = 4, + mixture_rank_dim: int = 128, + simplex_dim: int = 256, + log_tube_metrics: bool = False, + ): + super().__init__() + self.init_type = init_type + self.loss_type = loss_type + self.patch_size = patch_size + self.num_layers = num_layers + self.log_tube_metrics = log_tube_metrics + self.last_aux_metrics: dict[str, float] = {} + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + # num_physical_layers < num_layers enables weight tying: blocks are cycled. + if num_physical_layers <= 0: + num_physical_layers = num_layers + self.num_physical_layers = num_physical_layers + self.use_loop_embed = use_loop_embed + # Per-loop depth embeddings for extreme depth recurrence (Universal Transformer). + # When weight tying is active, these tell the shared block which depth it's simulating. + if use_loop_embed and num_physical_layers < num_layers: + self.loop_embed = nn.Parameter(torch.randn(num_layers, 1, 1, model_dim) * 0.02) + if use_bigram_gate: + self.bigram_gate = BigramHashGate(model_dim, bigram_hash_bins) + # SmearGate: per-dim learned gate blending current and previous token (SOTA: PR#198, #194). + self.use_smear_gate = smear_gate + if smear_gate: + self.smear_gate_param = nn.Parameter(torch.zeros(model_dim, dtype=torch.float32)) + # BigramHash: learned embedding table with per-layer lambda scaling (SOTA: PR#198, #194). + self.use_bigram_hash = bigram_hash + self.bigram_vocab_size = bigram_vocab_size + if bigram_hash: + bdim = bigram_dim if bigram_dim > 0 else model_dim + self.bigram_embed = nn.Embedding(bigram_vocab_size, bdim) + nn.init.zeros_(self.bigram_embed.weight) + if bdim != model_dim: + self.bigram_proj = nn.Linear(bdim, model_dim, bias=False) + self.bigram_proj._zero_init = True + else: + self.bigram_proj = None + self.bigram_lambdas = nn.Parameter(0.05 * torch.ones(num_layers)) + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.lambda_spectral = lambda_spectral + self.lambda_tube = lambda_tube + self.lambda_vicreg = lambda_vicreg + self.liquid_depth = liquid_depth + # Embedding dispatch: QTT, HashEmbedding, or standard nn.Embedding. + if use_qtt and _RESEARCH: + self.tok_emb = QuanticsLinear(vocab_size, model_dim, tt_rank, num_cores) + elif use_hash_embedding and _RESEARCH: + self.tok_emb = HashEmbedding(vocab_size, model_dim) + else: + self.tok_emb = nn.Embedding(vocab_size, model_dim) + # JEPA latent loss: frozen random orthogonal projection + trainable head. + if loss_type == "jepa_latent": + proj = torch.randn(vocab_size, latent_dim) + if latent_dim <= vocab_size: + proj = torch.linalg.qr(proj)[0] # Orthogonal rows for stability. + self.register_buffer("jepa_proj", proj) + self.jepa_head = CastedLinear(model_dim, latent_dim, bias=False) + # SP-4096: compress patch_size consecutive token embeddings into 1 vector. 
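+ # Shape sketch for the patching below (illustrative: patch_size=4, model_dim=512,
+ # T=2048): (B, 2048, 512) is viewed as (B, 512, 2048), patch_down maps it to
+ # (B, 512, 512) so the blocks run on a 4x shorter sequence, and patch_up inverts
+ # the compression before the output head so the loss stays token-level.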
+ if patch_size > 1: + self.patch_down = CastedLinear(model_dim * patch_size, model_dim, bias=False) + self.patch_up = CastedLinear(model_dim, model_dim * patch_size, bias=False) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + # Parse saliency skip layers and softmax anchor layers for hybrid attention. + _saliency_skip = os.environ.get("SALIENCY_SKIP_LAYERS", "") + _saliency_layers = set(int(x) for x in _saliency_skip.split(",") if x.strip()) + _anchor_set = set(int(x) for x in softmax_anchor_layers.split(",") if x.strip()) if softmax_anchor_layers else set() + # Safety: seq_len must be divisible by chunk_size for hybrid attention (avoids padding bugs). + _seq_len = int(os.environ.get("TRAIN_SEQ_LEN", "1024")) + if attention_pattern == "hybrid_linear" and _seq_len % linear_chunk_size != 0: + raise ValueError(f"TRAIN_SEQ_LEN={_seq_len} must be divisible by LINEAR_CHUNK_SIZE={linear_chunk_size}") + self.blocks = nn.ModuleList() + for i in range(num_physical_layers): + # Determine per-layer attention type for hybrid attention. + _layer_attn_type = attention_type + _layer_use_xsa = (i >= num_physical_layers - xsa_last_n) if xsa_last_n > 0 else False + # Token saliency: apply to all layers if no skip list, or only specified layers. + _layer_saliency = token_saliency and (not _saliency_layers or i in _saliency_layers) + block = Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + window_size, _layer_attn_type, mlp_type, + activation_type=activation_type, chebyshev_order=chebyshev_order, + use_fa3=use_fa3, ntk_seq_len=ntk_seq_len, max_seq_len=max_seq_len, + use_xsa=_layer_use_xsa, rope_dims=rope_dims, + ln_scale_factor=1.0 / math.sqrt(i + 1) if ln_scale else 1.0, + residual_mag_norm=residual_mag_norm, token_saliency=_layer_saliency, + saliency_chunk_sizes=saliency_chunk_sizes, + ) + # Hybrid linear/softmax: set linear attention on non-anchor layers. + if attention_pattern == "hybrid_linear" and i not in _anchor_set: + block.attn.use_linear_attn = True + block.attn._chunked_linear_attn = ChunkedCausalLinearAttention(chunk_size=linear_chunk_size) + block.attn.use_xsa = False # XSA incompatible with linear attention + self.blocks.append(block) + # Value Embeddings (VE128): reinject token identity into attention values at deep layers. + _ve_layer_set = set(int(x) for x in ve_layers.split(",") if x.strip()) if ve_layers else set() + self._ve_layer_set = _ve_layer_set + if value_embed and _ve_layer_set: + kv_dim = num_kv_heads * (model_dim // num_heads) + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in _ve_layer_set] + ) + else: + self.ve_shared = None + _hg_layer_set = set(int(x) for x in hypergraph_layers.split(",") if x.strip()) if hypergraph_layers else set() + self._hg_layer_set = _hg_layer_set + if hypergraph_lift and _hg_layer_set: + _hg_scales = [int(x) for x in hypergraph_scales.split(",") if x.strip()] + self.hypergraph = HypergraphLift(model_dim, _hg_scales) + else: + self.hypergraph = None + # Backout connection (PR#339): learned subtraction of mid-layer hidden state from output. 
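+ # Backout computes x_final = x_final - backout_lambda * h_mid with backout_lambda
+ # zero-initialized, so it is an exact no-op at step 0; the per-dim lambda lets the
+ # model learn to cancel features that the chosen mid layer contributed but that do
+ # not help the final prediction.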
+        # Backout connection (PR#339): learned subtraction of mid-layer hidden state from output.
+        self.backout_connection = backout_connection
+        self.backout_layer = min(backout_layer, num_layers - 1)
+        if backout_connection:
+            self.backout_lambda = nn.Parameter(torch.zeros(model_dim, dtype=torch.float32))
+        # Set hybrid equivalence attention anchors on all attention modules.
+        if num_anchor_tokens > 0:
+            for blk in self.blocks:
+                blk.attn.num_anchor_tokens = num_anchor_tokens
+        # Depth-routed weight sharing (Liquid Foundation Models).
+        self.depth_router = DepthRouter(model_dim, router_hidden) if liquid_depth > 0 else None
+        # Symplectic Hamiltonian recurrence (stable at extreme depth).
+        self.symplectic = SymplecticLiquidLayer(model_dim) if (liquid_type == "symplectic" and _RESEARCH) else None
+        self.final_norm = RMSNorm()
+        # Output head: factorized low-rank, tied embeddings, or standard linear.
+        self.factorized_head = FactorizedHead(model_dim, vocab_size, rank_dim) if rank_dim > 0 else None
+        self.mixture_head = None
+        self.simplex_head = None
+        if head_type == "mixture_softmax":
+            self.mixture_head = MixtureSoftmaxHead(model_dim, vocab_size, mixture_softmax_k, mixture_rank_dim)
+        elif head_type == "simplex":
+            self.simplex_head = SimplexHead(model_dim, vocab_size, simplex_dim)
+        self.lm_head = None if (tie_embeddings or rank_dim > 0) else CastedLinear(model_dim, vocab_size, bias=False)
+        if self.lm_head is not None:
+            self.lm_head._zero_init = True
+        self._init_weights()
+
+    def _init_weights(self) -> None:
+        if self.tie_embeddings and hasattr(self.tok_emb, 'weight'):
+            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
+        for module in self.modules():
+            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
+                nn.init.zeros_(module.weight)
+        # OvertoneInit: scale weight std by 1/sqrt(L) for NTK-aware convergence.
+        if self.init_type == "overtone":
+            L = max(self.num_encoder_layers + self.num_decoder_layers, 1)
+            scale = 1.0 / math.sqrt(L)
+            with torch.no_grad():
+                for module in self.blocks.modules():
+                    if isinstance(module, (nn.Linear, CastedLinear)):
+                        module.weight.mul_(scale)
+        # OrthoInit: orthogonal init for 2D matrices + projection scaling (SOTA: PR#198).
+        if self.init_type == "ortho":
+            with torch.no_grad():
+                for name, module in self.named_modules():
+                    if isinstance(module, (nn.Linear, CastedLinear)) and module.weight.ndim == 2:
+                        if module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
+                            nn.init.orthogonal_(module.weight, gain=1.0)
+                            if ".proj" in name or name.endswith(".proj"):
+                                module.weight.mul_(1.0 / math.sqrt(2 * self.num_layers))
+
+    def _apply_smear_gate(self, x: Tensor) -> Tensor:
+        """Blend each token embedding with previous token's embedding (per-dim gate, SOTA: PR#198)."""
+        g = torch.sigmoid(self.smear_gate_param.to(dtype=x.dtype))[None, None, :]
+        x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
+        return (1 - g) * x + g * x_prev
+
+    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
+        x = self.tok_emb(input_ids)
+        # SmearGate: apply BEFORE rms_norm (SOTA ordering from PR#198, #194).
+        if self.use_smear_gate:
+            x = self._apply_smear_gate(x)
+        # BigramHashGate (legacy): apply before patching so input_ids and x have matching seq length.
+        if hasattr(self, 'bigram_gate'):
+            x = self.bigram_gate(x, input_ids)
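+        # Editor's note on the SmearGate blend above: with g = sigmoid(smear_gate_param),
+        #   x_t <- (1 - g) * x_t + g * x_{t-1}   (per-dim; x_{-1} is a zero vector),
+        # and the zero-initialized gate starts at sigmoid(0) = 0.5, an even blend that
+        # training can push toward identity (g -> 0) or full smearing (g -> 1).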
+        # SP-4096: compress patch_size consecutive embeddings into one vector.
+        if self.patch_size > 1:
+            B, T, D = x.shape
+            x = x.reshape(B, T // self.patch_size, D * self.patch_size)  # (B, T/P, D*P)
+            x = self.patch_down(x)  # (B, T/P, D)
+        x = F.rms_norm(x, (x.size(-1),))
+        x0 = x
+        # BigramHash (SOTA): compute once, inject per-layer.
+        if self.use_bigram_hash:
+            bigram_emb = self.bigram_embed(get_bigram_hash(input_ids, self.bigram_vocab_size))
+            if self.bigram_proj is not None:
+                bigram_emb = self.bigram_proj(bigram_emb)
+        else:
+            bigram_emb = None
+        skips: list[Tensor] = []
+        _backout_h: Tensor | None = None
+        total_layers = self.liquid_depth if self.liquid_depth > 0 else (self.num_encoder_layers + self.num_decoder_layers)
+        # Value Embeddings: compute shared VE once, cache per-layer scaled versions.
+        _ve_cache: dict[int, Tensor] = {}
+        if self.ve_shared is not None:
+            _ve_base = self.ve_shared(input_ids)
+            _ve_sorted = sorted(self._ve_layer_set)
+            for idx, layer_i in enumerate(_ve_sorted):
+                _ve_cache[layer_i] = _ve_base * self.ve_layer_scales[idx].to(dtype=_ve_base.dtype)
+
+        # First half stores skips; second half reuses them in reverse order.
+        # When num_physical_layers < num_layers, blocks are cycled (weight tying).
+        for i in range(self.num_encoder_layers):
+            if bigram_emb is not None:
+                x = x + bigram_emb * self.bigram_lambdas[i].to(dtype=x.dtype)
+            if self.depth_router is not None:
+                x = self.depth_router(x, i / total_layers)
+            if self.use_loop_embed and hasattr(self, 'loop_embed'):
+                x = x + self.loop_embed[i]
+            x = self.blocks[i % self.num_physical_layers](x, x0, ve=_ve_cache.get(i))
+            if self.hypergraph is not None and i in self._hg_layer_set:
+                x = x + self.hypergraph(x)
+            if self.backout_connection and i == self.backout_layer:
+                _backout_h = x
+            skips.append(x)
+        for i in range(self.num_decoder_layers):
+            if skips:
+                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+            _layer_idx = self.num_encoder_layers + i
+            if bigram_emb is not None:
+                x = x + bigram_emb * self.bigram_lambdas[_layer_idx].to(dtype=x.dtype)
+            if self.depth_router is not None:
+                x = self.depth_router(x, _layer_idx / total_layers)
+            if self.use_loop_embed and hasattr(self, 'loop_embed'):
+                x = x + self.loop_embed[_layer_idx]
+            x = self.blocks[_layer_idx % self.num_physical_layers](x, x0, ve=_ve_cache.get(_layer_idx))
+            if self.hypergraph is not None and _layer_idx in self._hg_layer_set:
+                x = x + self.hypergraph(x)
+            if self.backout_connection and _layer_idx == self.backout_layer:
+                _backout_h = x
+
+        # Backout connection (PR#339): subtract learned fraction of mid-layer hidden state.
+        if _backout_h is not None:
+            x = x - self.backout_lambda.to(dtype=x.dtype)[None, None, :] * _backout_h
+
+        # Symplectic Hamiltonian recurrence (applied after standard blocks).
+        if self.symplectic is not None:
+            _symp_steps = self.liquid_depth if self.liquid_depth > 0 else 24
+            x = x + self.symplectic(x, num_steps=_symp_steps)
+
+        x_pre_head = x  # preserve for regularizers (before reshape)
+        if self.log_tube_metrics and not self.training:
+            self.last_aux_metrics = trajectory_metrics(x_pre_head.detach())
+        # SP-4096: expand back to original token-level resolution before output head.
+        if self.patch_size > 1:
+            x_expanded = self.patch_up(x)  # (B, T/P, D*P)
+            model_dim = x.size(-1)
+            x = x_expanded.reshape(x_expanded.shape[0], -1, model_dim)  # (B, T, D)
+        x = self.final_norm(x).reshape(-1, x.size(-1))
+        targets = target_ids.reshape(-1)
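+        # Editor's note on shapes here: with patching active, patch_up maps
+        # (B, T/P, D) -> (B, T/P, D*P) and the reshape restores (B, T, D); after
+        # final_norm, x is flattened to (B*T, D) and targets to (B*T,), so every
+        # head below maps (B*T, D) -> (B*T, vocab_size) logits.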
+        # JEPA latent loss: L2 in frozen random latent space (training only).
+        # Validation always uses standard cross-entropy for accurate BPB.
+        if hasattr(self, 'jepa_proj') and self.training:
+            pred_latent = self.jepa_head(x)  # (B*T, latent_dim)
+            target_latent = self.jepa_proj[targets]  # (B*T, latent_dim) — index into frozen proj
+            loss = F.mse_loss(pred_latent, target_latent)
+        else:
+            if self.mixture_head is not None:
+                logits_proj = self.mixture_head(x)
+            elif self.simplex_head is not None:
+                logits_proj = self.simplex_head(x)
+            elif self.factorized_head is not None:
+                logits_proj = self.factorized_head(x)
+            elif self.tie_embeddings:
+                logits_proj = F.linear(x, self.tok_emb.weight)
+            elif self.lm_head is not None:
+                logits_proj = self.lm_head(x)
+            else:
+                raise RuntimeError("No output head configured")
+            logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
+            loss = F.cross_entropy(logits.float(), targets, reduction="mean")
+        # Research regularizers (no-op when lambda=0).
+        loss = loss + spectral_radius_penalty(x_pre_head.float(), self.lambda_spectral)
+        loss = loss + tube_regularizer(x_pre_head.float(), self.lambda_tube)
+        loss = loss + vicreg_regularizer(x_pre_head.float(), self.lambda_vicreg)
+        return loss
+
+
+# -----------------------------
+# TRAINING
+# -----------------------------
+
+def main() -> None:
+    global zeropower_via_newtonschulz5
+
+    code = Path(__file__).read_text(encoding="utf-8")
+    args = Hyperparameters()
+    if not args.skip_compile:
+        zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)
+
+    # -----------------------------
+    # DISTRIBUTED + CUDA SETUP
+    # -----------------------------
+
+    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
+    rank = int(os.environ.get("RANK", "0"))
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+    if world_size <= 0:
+        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
+    if 8 % world_size != 0:
+        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
+    _base_accum = int(os.environ.get("GRAD_ACCUM_STEPS", str(8 // world_size)))
+    grad_accum_steps = _base_accum
+    grad_scale = 1.0 / grad_accum_steps
+    if not torch.cuda.is_available():
+        if not args.dev_mode:
+            raise RuntimeError("CUDA is required (set DEV_MODE=1 for local MPS/CPU testing)")
+        device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
+    else:
+        device = torch.device("cuda", local_rank)
+        torch.cuda.set_device(device)
+    if distributed:
+        dist.init_process_group(backend="nccl", device_id=device)
+        dist.barrier()
+    master_process = rank == 0
+
+    # Fast math knobs (CUDA only)
+    if device.type == "cuda":
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+        from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
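+        # Editor's note: disabling the math and mem-efficient SDP backends below
+        # (presumably) pins attention to the flash/cuDNN kernels only, so an
+        # unsupported shape or dtype fails loudly rather than silently falling
+        # back to a slow kernel inside the 600s budget.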
+        # Enable cuDNN SDP: Hopper-optimized attention kernel on H100 (cuDNN 9+).
+        # PR#505 achieves 74ms/step vs our 112ms — this may close the gap.
+        enable_cudnn_sdp(True)
+        enable_flash_sdp(True)
+        enable_mem_efficient_sdp(False)
+        enable_math_sdp(False)
+
+    def sync() -> None:
+        if device.type == "cuda":
+            torch.cuda.synchronize()
+
+    logfile = None
+    if master_process:
+        os.makedirs("logs", exist_ok=True)
+        logfile = f"logs/{args.run_id}.txt"
+        print(logfile)
+
+    def log0(msg: str, console: bool = True) -> None:
+        if not master_process:
+            return
+        if console:
+            print(msg)
+        if logfile is not None:
+            with open(logfile, "a", encoding="utf-8") as f:
+                print(msg, file=f)
+
+    log0(code, console=False)
+    log0("=" * 100, console=False)
+    log0(f"Running Python {sys.version}", console=False)
+    log0(f"Running PyTorch {torch.__version__}", console=False)
+    if device.type == "cuda":
+        log0(
+            subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
+            console=False,
+        )
+    log0("=" * 100, console=False)
+
+    # -----------------------------
+    # TOKENIZER + VALIDATION METRIC SETUP
+    # -----------------------------
+
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    if device.type == "cuda":
+        torch.cuda.manual_seed_all(args.seed)
+
+    if not args.tokenizer_path.endswith(".model"):
+        raise ValueError(f"Script is only set up for a SentencePiece .model file: {args.tokenizer_path}")
+    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
+    if int(sp.vocab_size()) != args.vocab_size:
+        raise ValueError(
+            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
+        )
+    dataset_dir = Path(args.data_path).resolve()
+    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
+    val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
+    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
+        sp, args.vocab_size, device
+    )
+    log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
+    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}")
+    log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}")
+
+    # -----------------------------
+    # MODEL + OPTIMIZER SETUP
+    # -----------------------------
+
+    base_model = GPT(
+        vocab_size=args.vocab_size,
+        num_layers=args.num_layers,
+        model_dim=args.model_dim,
+        num_heads=args.num_heads,
+        num_kv_heads=args.num_kv_heads,
+        mlp_mult=args.mlp_mult,
+        tie_embeddings=args.tie_embeddings,
+        tied_embed_init_std=args.tied_embed_init_std,
+        logit_softcap=args.logit_softcap,
+        rope_base=args.rope_base,
+        qk_gain_init=args.qk_gain_init,
+        num_physical_layers=args.num_physical_layers,
+        window_size=args.window_size,
+        num_anchor_tokens=args.num_anchor_tokens,
+        liquid_depth=args.liquid_depth,
+        router_hidden=args.router_hidden,
+        rank_dim=args.rank_dim,
+        lambda_spectral=args.lambda_spectral,
+        lambda_tube=args.lambda_tube,
+        lambda_vicreg=args.lambda_vicreg,
+        attention_type=args.attention_type,
+        mlp_type=args.mlp_type,
+        use_qtt=args.use_qtt,
+        tt_rank=args.tt_rank,
+        num_cores=args.num_cores,
+        use_hash_embedding=args.use_hash_embedding,
+        liquid_type=args.liquid_type,
+        init_type=args.init_type,
+        loss_type=args.loss_type,
+        latent_dim=args.latent_dim,
+        patch_size=args.patch_size,
+        use_loop_embed=args.use_loop_embed,
+        use_bigram_gate=args.use_bigram_gate,
+        bigram_hash_bins=args.bigram_hash_bins,
+        smear_gate=args.smear_gate,
+        bigram_hash=args.bigram_hash,
+        bigram_vocab_size=args.bigram_vocab_size,
+        bigram_dim=args.bigram_dim,
+        use_fa3=args.use_fa3,
+        ntk_seq_len=args.train_seq_len if args.use_ntk_rope else 0,
+        max_seq_len=args.train_seq_len,
+        xsa_last_n=args.xsa_last_n,
+        rope_dims=args.rope_dims,
+        ln_scale=args.ln_scale,
+        backout_connection=args.backout_connection,
+        backout_layer=args.backout_layer,
+        activation_type=args.activation_type,
+        chebyshev_order=args.chebyshev_order,
+        residual_mag_norm=args.residual_mag_norm,
+        token_saliency=args.token_saliency,
+        saliency_chunk_sizes=args.saliency_chunk_sizes,
+        value_embed=args.value_embed,
+        ve_dim=args.ve_dim,
+        ve_layers=args.ve_layers,
+        attention_pattern=args.attention_pattern,
+        softmax_anchor_layers=args.softmax_anchor_layers,
+        linear_chunk_size=args.linear_chunk_size,
+        hypergraph_lift=args.hypergraph_lift,
+        hypergraph_layers=args.hypergraph_layers,
+        hypergraph_scales=args.hypergraph_scales,
+        head_type=args.head_type,
+        mixture_softmax_k=args.mixture_softmax_k,
+        mixture_rank_dim=args.mixture_rank_dim,
+        simplex_dim=args.simplex_dim,
+        log_tube_metrics=args.log_tube_metrics,
+    ).to(device).bfloat16()
+    # Ternary quantization: monkey-patch CastedLinear before float conversion.
+    if args.use_ternary and _RESEARCH:
+        for module in base_model.modules():
+            if isinstance(module, CastedLinear):
+                module.__class__ = TernaryLinear
+    # QAT: wrap CastedLinear layers with fake quantization (STE).
+    # Must recurse into block.attn and block.mlp — Block's direct children are sub-modules, not CastedLinear.
+    # Late QAT (PR#315): same wrapping but starts disabled, activated when lr_scale < qat_threshold.
+    if args.use_qat or args.late_qat:
+        _qat_count = 0
+        _mlp_bits = args.quant_bits_mlp if args.quant_bits_mlp > 0 else args.qat_bits
+        for block in base_model.modules():
+            if not isinstance(block, Block):
+                continue
+            for name, child in list(block.attn.named_children()):
+                if isinstance(child, CastedLinear):
+                    setattr(block.attn, name, QATLinear(child, bits=args.qat_bits))
+                    _qat_count += 1
+            for name, child in list(block.mlp.named_children()):
+                if isinstance(child, CastedLinear):
+                    setattr(block.mlp, name, QATLinear(child, bits=_mlp_bits))
+                    _qat_count += 1
+        log0(f"QAT: attn={args.qat_bits}b mlp={_mlp_bits}b on {_qat_count} CastedLinear layers")
+        if args.late_qat:
+            # Late QAT (PR#315): start disabled, activate when lr_scale < qat_threshold
+            for m in base_model.modules():
+                if isinstance(m, QATLinear):
+                    m.enabled = False
+            log0(f"Late QAT: activates when lr_scale < {args.qat_threshold}")
+        elif args.qat_start_frac > 0:
+            for m in base_model.modules():
+                if isinstance(m, QATLinear):
+                    m.enabled = False
+            log0(f"QAT deferred: activates at {args.qat_start_frac*100:.0f}% of training")
+    for module in base_model.modules():
+        if isinstance(module, CastedLinear):
+            module.float()
+    restore_low_dim_params_to_fp32(base_model)
+    # Skip torch.compile in dev mode (MPS/CPU) or when ALLMEM uses in-forward autograd.grad().
+    _skip_compile = args.dev_mode or args.skip_compile or args.attention_type == "allmem"
+    # Context curriculum changes seq_len mid-training → dynamic shapes needed.
+    # Rotary buffers are pre-allocated (no object ID changes), but input tensor shapes change.
+    _use_dynamic = bool(args.seq_curriculum)
+    if _use_dynamic:
+        # Allow enough recompiles for curriculum stages (3 stages = 3 recompiles, not the default 8 limit).
+        torch._dynamo.config.cache_size_limit = 64
+    if distributed:
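+        # Editor's note (an assumption about intent): turning off optimize_ddp below
+        # gives up Dynamo's DDP bucket-aligned graph splitting in exchange for one
+        # compiled graph that tolerates flash-attn's higher-order ops.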
+        # DDP optimizer chokes on higher-order ops (flash_attn) in both dynamic and fullgraph modes.
+        torch._dynamo.config.optimize_ddp = False
+    if not _skip_compile:
+        try:
+            compiled_model = torch.compile(
+                base_model,
+                dynamic=_use_dynamic,
+                fullgraph=not _use_dynamic,  # fullgraph incompatible with dynamic shapes
+            )
+            if _use_dynamic:
+                log0(f"torch.compile: dynamic=True, cache_size_limit={torch._dynamo.config.cache_size_limit}, optimize_ddp={torch._dynamo.config.optimize_ddp}")
+        except Exception as e:
+            print(f"WARNING: torch.compile failed ({e}), falling back to eager mode", file=sys.stderr)
+            compiled_model = base_model
+    else:
+        compiled_model = base_model
+    model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model
+    eval_model: nn.Module = base_model if args.log_tube_metrics else model
+
+    # Optimizer split:
+    # - token embedding (Adam) uses EMBED_LR
+    # - untied lm_head (Adam) uses HEAD_LR
+    # - matrix params in transformer blocks use MATRIX_LR via Muon
+    # - vectors/scalars use SCALAR_LR via Adam
+    block_named_params = list(base_model.blocks.named_parameters())
+    matrix_params = [
+        p
+        for name, p in block_named_params
+        if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+    ]
+    scalar_params = [
+        p
+        for name, p in block_named_params
+        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+    ]
+    if base_model.skip_weights.numel() > 0:
+        scalar_params.append(base_model.skip_weights)
+    # SOTA features: add SmearGate, BigramHash params to optimizer (scalar/Adam group).
+    if base_model.use_smear_gate:
+        scalar_params.append(base_model.smear_gate_param)
+    if base_model.backout_connection:
+        scalar_params.append(base_model.backout_lambda)
+    if base_model.use_bigram_hash:
+        scalar_params.append(base_model.bigram_lambdas)
+    # BigramHash: embed.weight → AdamW (tok group), proj.weight → Muon (matrix group) (PR#162)
+    _extra_embed_params: list[nn.Parameter] = []
+    if base_model.use_bigram_hash:
+        _extra_embed_params.append(base_model.bigram_embed.weight)
+        if base_model.bigram_proj is not None:
+            matrix_params.append(base_model.bigram_proj.weight)
+    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
+    _tok_param_groups = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}]
+    if _extra_embed_params:
+        _tok_param_groups.append({"params": _extra_embed_params, "lr": token_lr, "base_lr": token_lr})
+    optimizer_tok = torch.optim.AdamW(
+        _tok_param_groups,
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        weight_decay=args.adam_weight_decay,
+        fused=(device.type == "cuda"),
+    )
+    if args.optimizer_type == "conda" and _RESEARCH:
+        optimizer_muon = CondaOptimizer(matrix_params, lr=args.matrix_lr, betas=(args.beta1, args.beta2))
+    elif args.optimizer_type == "matrix_free_cg" and _RESEARCH:
+        _cg_iters = int(os.environ.get("CG_ITERS", "5"))
+        optimizer_muon = MatrixFreeCGOptimizer(matrix_params, lr=args.matrix_lr, cg_iters=_cg_iters)
+    elif args.optimizer_type == "qes" and _RESEARCH:
+        _qes_pop = int(os.environ.get("QES_POP_SIZE", "16"))
+        _qes_sigma = float(os.environ.get("QES_SIGMA", "0.05"))
+        qes_optimizer = QESOptimizer(base_model, lr=args.matrix_lr, pop_size=_qes_pop, sigma=_qes_sigma)
+        # Still need a placeholder for the optimizers list (used for lr scheduling)
+        optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps, variant=args.optimizer_variant, weight_decay=args.muon_weight_decay)
+        log0(f"QES optimizer: forward-only mode, pop_size={_qes_pop}, sigma={_qes_sigma}")
+    else:
+        optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps, variant=args.optimizer_variant, weight_decay=args.muon_weight_decay)
+    for group in optimizer_muon.param_groups:
+        group["base_lr"] = args.matrix_lr
+    optimizer_scalar = torch.optim.AdamW(
+        [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        weight_decay=args.adam_weight_decay,
+        fused=(device.type == "cuda"),
+    )
+    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
+    if base_model.lm_head is not None:
+        optimizer_head = torch.optim.AdamW(
+            [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
+            betas=(args.beta1, args.beta2),
+            eps=args.adam_eps,
+            weight_decay=args.adam_weight_decay,
+            fused=(device.type == "cuda"),
+        )
+        optimizers.insert(1, optimizer_head)
+
+    n_params = sum(p.numel() for p in base_model.parameters())
+    log0(f"model_params:{n_params}")
+    log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}")
+    log0("sdp_backends:cudnn=True flash=True mem_efficient=False math=False")
+    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
+    log0(
+        f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} "
+        f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} "
+        f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}"
+    )
+    log0(
+        f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} "
+        f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} "
+        f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
+    )
+    log0(f"seed:{args.seed}")
+
+    # -----------------------------
+    # DATA LOADER & MODEL WARMUP
+    # -----------------------------
+
+    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device, curriculum=args.data_curriculum)
+
+    def zero_grad_all() -> None:
+        for opt in optimizers:
+            opt.zero_grad(set_to_none=True)
+
+    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None
+
+    def lr_mul(step: int, elapsed_ms: float) -> float:
+        if args.warmdown_iters <= 0:
+            return 1.0
+        if max_wallclock_ms is None:
+            warmdown_start = max(args.iterations - args.warmdown_iters, 0)
+            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0
+        step_ms = elapsed_ms / max(step, 1)
+        warmdown_ms = args.warmdown_iters * step_ms
+        remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
+        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0
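+    # Editor's worked example for the wallclock branch of lr_mul (numbers illustrative):
+    # at ~135 ms/step, warmdown_iters=700 gives warmdown_ms ≈ 94.5 s; once the
+    # remaining budget drops below that, the multiplier decays linearly toward 0,
+    # so the LR lands near 0 when the cap is hit regardless of how many steps fit.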
+    # Warmup primes the compiled forward/backward/optimizer paths, then we restore the
+    # initial weights/optimizer state so measured training starts from the true init.
+    if args.warmup_steps > 0:
+        initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
+        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
+        model.train()
+        for warmup_step in range(args.warmup_steps):
+            zero_grad_all()
+            for micro_step in range(grad_accum_steps):
+                if distributed:
+                    model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
+                x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
+                with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True):
+                    warmup_loss = model(x, y)
+                (warmup_loss * grad_scale).backward()
+            for opt in optimizers:
+                opt.step()
+            zero_grad_all()
+            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
+                log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")
+        base_model.load_state_dict(initial_model_state, strict=True)
+        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
+            opt.load_state_dict(state)
+        zero_grad_all()
+        if distributed:
+            model.require_backward_grad_sync = True
+        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device, curriculum=args.data_curriculum)
+
+    # -----------------------------
+    # SWA (Stochastic Weight Averaging) / EMA (Exponential Moving Average)
+    # -----------------------------
+    _swa_state: dict | None = None
+    _swa_count = 0
+    _ema_state: dict | None = None
+    _ema_started = False
+    _use_ema = args.use_ema  # flag checked in training loop; allocation deferred to first collection step
+    # PR#315: EMA from step 0 — init shadow weights immediately, update every step (no guards)
+    if _use_ema and args.ema_from_init:
+        _ema_state = {k: v.detach().float().clone() for k, v in base_model.state_dict().items()}
+        _ema_started = True
+        log0(f"ema:initialized from model init (alpha={args.ema_alpha})")
+    if not _use_ema and args.use_swa:
+        _swa_state = {k: torch.zeros_like(v) for k, v in base_model.state_dict().items()}
+
+    # Gradient-guided quant: accumulate gradient magnitudes during warmdown for adaptive bit assignment.
+    _grad_sensitivity: dict[str, float] = {}
+    _grad_accum_started = False
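+    # Editor's note on the EMA update used later: shadow <- a*shadow + (1-a)*theta each
+    # step, so a=0.997 averages over roughly 1/(1-a) ≈ 333 recent steps; ema_from_init
+    # just starts that average at step 0 instead of waiting for the warmdown phase.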
+    # Spectral warmdown: per-layer LR scaling factors (computed once at warmdown start).
+    _spectral_lr_scales: dict[int, float] = {}  # param id -> scale factor
+    _spectral_computed = False
+
+    # -----------------------------
+    # CONTEXT CURRICULUM (PR#203: train at shorter seq_len first for more steps)
+    # -----------------------------
+    _seq_curriculum_stages: list[tuple[int, float]] = []
+    if args.seq_curriculum:
+        for stage in args.seq_curriculum.split(","):
+            slen_str, frac_str = stage.strip().split(":")
+            _seq_curriculum_stages.append((int(slen_str), float(frac_str)))
+        _seq_curriculum_stages.sort(key=lambda x: x[1])
+        log0(f"seq_curriculum: {_seq_curriculum_stages}")
+
+    def _get_seq_len(frac: float) -> int:
+        """Return current sequence length based on training progress fraction."""
+        if not _seq_curriculum_stages:
+            return args.train_seq_len
+        for slen, threshold in _seq_curriculum_stages:
+            if frac < threshold:
+                return slen
+        return args.train_seq_len  # fallback to full seq_len
+
+    # -----------------------------
+    # MAIN TRAINING LOOP
+    # -----------------------------
+
+    training_time_ms = 0.0
+    stop_after_step: int | None = None
+    _qat_activated = not (args.use_qat or args.late_qat) or (args.use_qat and args.qat_start_frac <= 0.0 and not args.late_qat)
+    _prev_curriculum_seq_len = args.train_seq_len
+    sync()
+    t0 = time.perf_counter()
+
+    step = 0
+    while True:
+        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)
+
+        should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
+        if should_validate:
+            sync()
+            training_time_ms += 1000.0 * (time.perf_counter() - t0)
+            val_loss, val_bpb = eval_val(
+                args,
+                eval_model,
+                rank,
+                world_size,
+                device,
+                grad_accum_steps,
+                val_tokens,
+                base_bytes_lut,
+                has_leading_space_lut,
+                is_boundary_token_lut,
+            )
+            log0(
+                f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
+                f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
+            )
+            _metrics_src = base_model if args.log_tube_metrics else (model.module if hasattr(model, "module") else model)
+            _aux = getattr(_metrics_src, "last_aux_metrics", None)
+            if args.log_tube_metrics and _aux:
+                log0(
+                    f"trajectory_metrics: drift_cos:{_aux.get('drift_cos', 0.0):.4f} "
+                    f"curvature:{_aux.get('curvature', 0.0):.6f} isotropy:{_aux.get('isotropy', 0.0):.4f}"
+                )
+            sync()
+            t0 = time.perf_counter()
+
+        if last_step:
+            if stop_after_step is not None and step < args.iterations:
+                log0(
+                    f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms "
+                    f"step:{step}/{args.iterations}"
+                )
+            break
+
+        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        scale = lr_mul(step, elapsed_ms)
+        if not _qat_activated:
+            # Late QAT (PR#315): activate when lr_scale drops below threshold
+            _should_activate_qat = False
+            if args.late_qat:
+                _should_activate_qat = scale < args.qat_threshold
+            elif args.use_qat and step / max(args.iterations, 1) >= args.qat_start_frac:
+                _should_activate_qat = True
+            if _should_activate_qat:
+                for m in base_model.modules():
+                    if isinstance(m, QATLinear):
+                        m.enabled = True
+                _qat_activated = True
+                _qat_label = "Late QAT" if args.late_qat else "QAT"
+                log0(f"{_qat_label} activated at step {step} (lr_scale={scale:.4f})")
+        if args.optimizer_type == "qes" and _RESEARCH:
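+            # Editor's note (assuming QESOptimizer, defined earlier in this file, is an
+            # evolution-strategies variant): sample weight perturbations, score each
+            # with a forward pass, and move along the score-weighted average, which
+            # is why this branch computes no gradients.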
+            # QES: zeroth-order, forward-only. No backward pass needed.
+            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
+            with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True):
+                qes_optimizer.step(None, x, y)  # model called internally
+                with torch.no_grad():
+                    train_loss = compiled_model(x, y).detach()
+        else:
+            # Context curriculum: compute current seq_len based on wallclock progress.
+            if _seq_curriculum_stages and max_wallclock_ms:
+                _wc_frac = elapsed_ms / max_wallclock_ms
+                _curr_seq_len = _get_seq_len(_wc_frac)
+                if _curr_seq_len != _prev_curriculum_seq_len:
+                    log0(f"seq_curriculum: switching to seq_len={_curr_seq_len} at step {step} ({_wc_frac*100:.0f}% wallclock)")
+                    _prev_curriculum_seq_len = _curr_seq_len
+            else:
+                _curr_seq_len = args.train_seq_len
+            zero_grad_all()
+            train_loss = torch.zeros((), device=device)
+            for micro_step in range(grad_accum_steps):
+                if distributed:
+                    model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
+                x, y = train_loader.next_batch(args.train_batch_tokens, _curr_seq_len, grad_accum_steps)
+                with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True):
+                    loss = model(x, y)
+                train_loss += loss.detach()
+                (loss * grad_scale).backward()
+            train_loss /= grad_accum_steps
+
+            # Gradient-guided quant: accumulate grad magnitudes during warmdown for adaptive bit assignment.
+            if args.grad_guided_quant and scale < 1.0:
+                if not _grad_accum_started:
+                    _grad_sensitivity = {n: 0.0 for n, p in base_model.named_parameters() if p.ndim >= 2}
+                    _grad_accum_started = True
+                    log0(f"grad_guided_quant: started accumulating ({len(_grad_sensitivity)} tensors)")
+                for n, p in base_model.named_parameters():
+                    if p.grad is not None and n in _grad_sensitivity:
+                        _grad_sensitivity[n] += p.grad.detach().float().abs().mean().item()
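+            # Editor's note: this running sum of mean |grad| per tensor is a cheap
+            # sensitivity proxy; tensors still receiving large gradients late in
+            # training later get more bits via compute_grad_bit_map at serialization.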
+            # Spectral warmdown: estimate a per-matrix norm once at warmdown start and scale LR.
+            # (One unnormalized Gaussian matvec: E||Wv||^2 = ||W||_F^2, so this tracks the
+            # Frobenius norm as a relative scale rather than a true spectral norm.)
+            if args.spectral_warmdown and scale < 1.0 and not _spectral_computed:
+                _spectral_norms: dict[int, float] = {}
+                for n, p in base_model.named_parameters():
+                    if p.ndim == 2 and p.shape[0] >= 64 and p.shape[1] >= 64:
+                        with torch.no_grad():
+                            v = torch.randn(p.shape[1], device=p.device, dtype=p.dtype)
+                            u = p @ v
+                            s = u.norm().item()
+                        _spectral_norms[id(p)] = max(s, 1e-8)
+                if _spectral_norms:
+                    mean_s = sum(_spectral_norms.values()) / len(_spectral_norms)
+                    _spectral_lr_scales = {pid: s / mean_s for pid, s in _spectral_norms.items()}
+                    log0(f"spectral_warmdown: computed scales for {len(_spectral_lr_scales)} params "
+                         f"(range {min(_spectral_lr_scales.values()):.3f}-{max(_spectral_lr_scales.values()):.3f})")
+                _spectral_computed = True
+
+        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
+        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
+        for group in optimizer_muon.param_groups:
+            group["momentum"] = muon_momentum
+
+        for opt in optimizers:
+            for group in opt.param_groups:
+                base_scale = group["base_lr"] * scale
+                if _spectral_lr_scales and scale < 1.0:
+                    # Spectral warmdown: per-param LR scaling by the estimated norm
+                    for p in group["params"]:
+                        if id(p) in _spectral_lr_scales:
+                            p._spectral_scale = _spectral_lr_scales[id(p)]
+                    # Use mean spectral scale for the group-level LR
+                    scales_in_group = [_spectral_lr_scales[id(p)] for p in group["params"] if id(p) in _spectral_lr_scales]
+                    if scales_in_group:
+                        group["lr"] = base_scale * (sum(scales_in_group) / len(scales_in_group))
+                    else:
+                        group["lr"] = base_scale
+                else:
+                    group["lr"] = base_scale
+
+        if args.grad_clip_norm > 0:
+            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
+        for opt in optimizers:
+            if _RESEARCH and isinstance(opt, MatrixFreeCGOptimizer):
+                opt.step(loss)  # CG needs the loss tensor for HVP computation
+            else:
+                opt.step()
+        zero_grad_all()
+
+        # Cross-layer weight coherence regularizer: soft weight sharing penalty.
+        # Computed outside compiled model to avoid graph breaks from iterating named_parameters.
+        if args.lambda_coherence > 0 and step % 10 == 0:
+            _coh_loss = torch.zeros((), device=device)
+            _coh_n = 0
+            for _bi in range(len(base_model.blocks) - 1):
+                for (_, _p1), (_, _p2) in zip(base_model.blocks[_bi].named_parameters(), base_model.blocks[_bi + 1].named_parameters()):
+                    if _p1.shape == _p2.shape and _p1.ndim == 2:
+                        _coh_loss = _coh_loss + (_p1 - _p2).square().mean()
+                        _coh_n += 1
+            if _coh_n > 0:
+                (args.lambda_coherence * _coh_loss / _coh_n).backward()
+                for opt in optimizers:
+                    opt.step()
+                zero_grad_all()
+
+        # EMA: exponential moving average of weights — biases toward recent weights (PR#201, #223).
+        # Two modes: ema_from_init (PR#315: every step, no guards) vs delayed (original: warmdown only).
+        if _use_ema and _ema_started and args.ema_from_init:
+            # PR#315 mode: update every step, no step/scale guards
+            with torch.no_grad():
+                alpha = args.ema_alpha
+                for k, v in base_model.state_dict().items():
+                    _ema_state[k].mul_(alpha).add_(v.detach().float(), alpha=1.0 - alpha)
+        elif _use_ema and not args.ema_from_init and step >= 200 and scale < args.ema_start_ratio:
+            with torch.no_grad():
+                if not _ema_started:
+                    # Allocate and initialize shadow weights to current model state (deferred to save memory).
+                    _ema_state = {k: v.detach().float().clone() for k, v in base_model.state_dict().items()}
+                    _ema_started = True
+                    log0(f"ema:started step:{step} lr_scale:{scale:.4f} alpha:{args.ema_alpha}")
+                else:
+                    alpha = args.ema_alpha
+                    for k, v in base_model.state_dict().items():
+                        _ema_state[k].mul_(alpha).add_(v.detach().float(), alpha=1.0 - alpha)
+        # SWA: accumulate weights during warmdown phase at fixed intervals (SOTA: every 50 steps).
+        elif _swa_state is not None and scale < args.swa_start_ratio and step % args.swa_every == 0:
+            with torch.no_grad():
+                for k, v in base_model.state_dict().items():
+                    _swa_state[k].add_(v)
+                _swa_count += 1
+                if _swa_count == 1:
+                    log0(f"swa:started step:{step} lr_scale:{scale:.4f}")
+
+        step += 1
+        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        should_log_train = (
+            args.train_log_every > 0
+            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
+        )
+        if should_log_train:
+            log0(
+                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} "
+                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
+            )
+
+        # Needed to sync whether we've reached the wallclock cap.
+        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+        if distributed and max_wallclock_ms is not None:
+            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
+            reached_cap = bool(reached_cap_tensor.item())
+        if stop_after_step is None and reached_cap:
+            stop_after_step = step
+
+    if device.type == "cuda":
+        log0(
+            f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+            f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
+        )
+
+    # EMA: load exponentially averaged weights back into model.
+    if _ema_state is not None and _ema_started:
+        log0(f"ema:loading shadow weights (alpha={args.ema_alpha})")
+        _ref_dtypes = {k: v.dtype for k, v in base_model.state_dict().items()}  # snapshot dtypes once (avoids rebuilding state_dict() per key)
+        ema_state_typed = {k: v.to(dtype=_ref_dtypes[k]) for k, v in _ema_state.items()}
+        base_model.load_state_dict(ema_state_typed, strict=True)
+        del _ema_state
+    # SWA: average accumulated weights and load back into model.
+    elif _swa_state is not None and _swa_count > 0:
+        log0(f"swa:averaging {_swa_count} checkpoints")
+        _ref_dtypes = {k: v.dtype for k, v in base_model.state_dict().items()}
+        avg_state = {k: (v / _swa_count).to(dtype=_ref_dtypes[k]) for k, v in _swa_state.items()}
+        base_model.load_state_dict(avg_state, strict=True)
+        del _swa_state  # free memory
+
+    # Magnitude pruning: zero out smallest weights for better compression (PR#205: 2%).
+    if args.prune_fraction > 0:
+        _sd = base_model.state_dict()
+        _sd_pruned = prune_state_dict(_sd, args.prune_fraction)
+        base_model.load_state_dict(_sd_pruned, strict=True)
+        log0(f"magnitude_pruning: zeroed {args.prune_fraction*100:.1f}% of 2D weight values")
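+    # Editor's note: pruning helps here because zeroed weights quantize to runs of
+    # identical bytes, which the zlib/zstd stage compresses cheaply; e.g. 2% pruning
+    # trades a small bpb hit for a smaller artifact.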
+    # -----------------------------
+    # FULL-WEIGHT TEST-TIME TRAINING (PR#254 SOTA: 1.1303 bpb)
+    # -----------------------------
+    # Adapt model weights on validation data via SGD before quantization.
+    # The adapted weights get quantized and submitted — effectively fine-tuning the model
+    # on the exact data it will be evaluated on. Allowed by competition rules.
+    if args.use_ttt and not args.skip_quant:
+        apply_full_weight_ttt(
+            base_model, val_tokens, device, args,
+            rank=rank, world_size=world_size, log_fn=log0,
+        )
+
+    # -----------------------------
+    # SERIALIZATION + ROUNDTRIP VALIDATION
+    # -----------------------------
+    # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
+    # the compressed int8+zlib artifact and validate the round-tripped weights.
+    # SKIP_QUANT=1 skips this section entirely (fast local dev / mini-runs).
+    if args.skip_quant:
+        if master_process:
+            log0("skip_quant=True: skipping serialization and roundtrip validation")
+        if distributed:
+            dist.destroy_process_group()
+        return
+
+    if master_process:
+        torch.save(base_model.state_dict(), "final_model.pt")
+        model_bytes = os.path.getsize("final_model.pt")
+        code_bytes = len(code.encode("utf-8"))
+        log0(f"Serialized model: {model_bytes} bytes")
+        log0(f"Code size: {code_bytes} bytes")
+        log0(f"Total submission size: {model_bytes + code_bytes} bytes")
+
+    # Gradient-guided quant: compute per-tensor bit assignments from accumulated gradient sensitivity.
+    _grad_bit_map = None
+    if args.grad_guided_quant and _grad_sensitivity:
+        _grad_bit_map = compute_grad_bit_map(_grad_sensitivity, quant_bits_middle=args.quant_bits_middle)
+        if master_process:
+            log0(f"grad_guided_quant: {len(_grad_bit_map)} tensors assigned adaptive bits")
+            _bit_counts = {}
+            for b in _grad_bit_map.values():
+                _bit_counts[b] = _bit_counts.get(b, 0) + 1
+            log0(f"  bit distribution: {dict(sorted(_bit_counts.items()))}")
+    quant_obj, quant_stats = quantize_state_dict_int8(
+        base_model.state_dict(),
+        quant_bits_middle=args.quant_bits_middle,
+        quant_bits_mlp=args.quant_bits_mlp,
+        fp16_embeddings=args.fp16_embeddings,
+        grad_bit_map=_grad_bit_map,
+        error_diffusion=args.quant_error_diffusion,
+        hadamard=args.quant_hadamard,
+    )
+    quant_buf = io.BytesIO()
+    torch.save(quant_obj, quant_buf)
+    quant_raw = quant_buf.getvalue()
+    if args.compress_algo == "zstd":
+        import zstandard
+        quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw)
+    else:
+        quant_blob = zlib.compress(quant_raw, level=9)
+    quant_raw_bytes = len(quant_raw)
+    if master_process:
+        with open("final_model.int8.ptz", "wb") as f:
+            f.write(quant_blob)
+        quant_file_bytes = os.path.getsize("final_model.int8.ptz")
+        code_bytes = len(code.encode("utf-8"))
+        ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
+        _calgo = args.compress_algo
+        log0(
+            f"Serialized model int8+{_calgo}: {quant_file_bytes} bytes "
+            f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)"
+        )
+        log0(f"Total submission size int8+{_calgo}: {quant_file_bytes + code_bytes} bytes")
+
+    if distributed:
+        dist.barrier()
+    with open("final_model.int8.ptz", "rb") as f:
+        quant_blob_disk = f.read()
+    if args.compress_algo == "zstd":
+        import zstandard
+        _decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk)
+    else:
+        _decompressed = zlib.decompress(quant_blob_disk)
+    quant_state = torch.load(io.BytesIO(_decompressed), map_location="cpu")
+    base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True)
+    sync()
+    t_qeval = time.perf_counter()
+    q_val_loss, q_val_bpb = eval_val(
+        args,
+        eval_model,
+        rank,
+        world_size,
+        device,
+        grad_accum_steps,
+        val_tokens,
+        base_bytes_lut,
+        has_leading_space_lut,
+        is_boundary_token_lut,
+    )
+    sync()
+    log0(
+        f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
+        f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
+    )
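+    # Editor's note: the roundtrip above reloads the artifact exactly as a consumer
+    # would (read file -> decompress -> torch.load -> dequantize -> load_state_dict),
+    # so the final val_bpb reported here reflects the quantized weights on disk, not
+    # the in-memory float weights.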
f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + _final_metrics_src = base_model if args.log_tube_metrics else (model.module if hasattr(model, "module") else model) + _final_aux = getattr(_final_metrics_src, "last_aux_metrics", None) + if args.log_tube_metrics and _final_aux: + log0( + f"final_trajectory_metrics: drift_cos:{_final_aux.get('drift_cos', 0.0):.4f} " + f"curvature:{_final_aux.get('curvature', 0.0):.6f} isotropy:{_final_aux.get('isotropy', 0.0):.4f}" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main()