diff --git a/records/track_10min_16mb/2026-04-30_SP1024_ValueResid_PPMMix/ train.log b/records/track_10min_16mb/2026-04-30_SP1024_ValueResid_PPMMix/ train.log new file mode 100644 index 0000000000..a432146348 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_SP1024_ValueResid_PPMMix/ train.log @@ -0,0 +1,92 @@ +output/run_sweep_record_try_sp1024_v1/record_try_sp1024_512d_mlp2_value_resid_last2_ppm_hi075_steps2200_20260501_051628/20260501_051631.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +Architecture: Discrete N-Gram Hash (Max N=2) +lora_params:0 +model_params:17385709 +world_size:1 grad_accum_steps:4 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True matrix_lr:0.04 scalar_lr:0.04 +ttt_enabled:False ttt_mode:lora lora_ttt_enabled:False +parallel_v2_enabled:0 mode:dual_add second_lane:mlp active_layers:[] second_lane_params:0 +W0501 05:17:04.134000 130660118263360 torch/fx/experimental/symbolic_shapes.py:4449] [0/0_1] q0 is not in var_ranges, defaulting to unknown range. +W0501 05:17:04.175000 130660118263360 torch/fx/experimental/symbolic_shapes.py:4449] [0/0_1] z0 is not in var_ranges, defaulting to unknown range. +W0501 05:17:06.797000 130660118263360 torch/fx/experimental/symbolic_shapes.py:4449] [0/0_1] x0 is not in var_ranges, defaulting to unknown range. +W0501 05:17:43.264000 130660118263360 torch/fx/experimental/symbolic_shapes.py:4449] [0/1] q0 is not in var_ranges, defaulting to unknown range. +W0501 05:17:43.279000 130660118263360 torch/fx/experimental/symbolic_shapes.py:4449] [0/1] z0 is not in var_ranges, defaulting to unknown range. +W0501 05:17:45.391000 130660118263360 torch/fx/experimental/symbolic_shapes.py:4449] [0/1] x0 is not in var_ranges, defaulting to unknown range. 
+warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +EMA Enabled: decay=0.997 +Scheduled Late QAT to start at step 1870 (last 15.0%) +step:0/2200 val_loss:6.9311 val_bpb:4.1050 train_time:3ms step_avg:3.24ms +step:1/2200 train_loss:6.9310 train_time:5646ms step_avg:5645.96ms +step:2/2200 train_loss:6.7587 train_time:6260ms step_avg:3130.22ms +step:3/2200 train_loss:6.3280 train_time:6876ms step_avg:2291.92ms +step:4/2200 train_loss:6.0142 train_time:7491ms step_avg:1872.66ms +step:5/2200 train_loss:5.8538 train_time:8105ms step_avg:1621.03ms +step:6/2200 train_loss:5.7285 train_time:8720ms step_avg:1453.37ms +step:7/2200 train_loss:5.6112 train_time:9335ms step_avg:1333.56ms +step:8/2200 train_loss:5.5435 train_time:9950ms step_avg:1243.71ms +step:9/2200 train_loss:5.4271 train_time:10564ms step_avg:1173.82ms +step:10/2200 train_loss:5.3285 train_time:11179ms step_avg:1117.95ms +step:200/2200 train_loss:2.6610 train_time:127902ms step_avg:639.51ms +step:400/2200 train_loss:2.3472 train_time:250871ms step_avg:627.18ms +step:600/2200 train_loss:2.4710 train_time:373898ms step_avg:623.16ms +step:800/2200 train_loss:2.3273 train_time:496924ms step_avg:621.15ms +step:1000/2200 train_loss:2.3827 train_time:619883ms step_avg:619.88ms +step:1000/2200 val_loss:2.3569 val_bpb:1.3959 train_time:619884ms step_avg:619.88ms +step:1200/2200 train_loss:2.2925 train_time:742828ms step_avg:619.02ms +step:1400/2200 train_loss:2.3217 train_time:865823ms step_avg:618.44ms +step:1600/2200 train_loss:2.1926 train_time:988786ms step_avg:617.99ms +step:1800/2200 train_loss:2.2291 train_time:1111692ms step_avg:617.61ms +[Step 1870] Activating Late QAT — enabling branchless STE quantization. +step:2000/2200 train_loss:2.1763 train_time:1234584ms step_avg:617.29ms +step:2000/2200 val_loss:2.1922 val_bpb:1.2984 train_time:1234584ms step_avg:617.29ms +step:2200/2200 train_loss:2.1102 train_time:1357536ms step_avg:617.06ms +step:2200/2200 val_loss:2.1771 val_bpb:1.2894 train_time:1357536ms step_avg:617.06ms +peak memory allocated: 29237 MiB reserved: 30568 MiB +Applying EMA weights for final evaluation... +saved raw checkpoint: output/run_sweep_record_try_sp1024_v1/record_try_sp1024_512d_mlp2_value_resid_last2_ppm_hi075_steps2200_20260501_051628/final_model.pt (67,874,184 bytes) +model_size int8+zlib:15650103 bytes code:156032 bytes total:15806135 bytes limit:16MB(16777216) FITS + payload:17568410 raw_torch:17615294 compression_ratio:3.86x +/workspace/parameter-golf/mytrain_gpt_v6_1.py:3216: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. 
We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") +final_int8_zlib_roundtrip val_loss:2.1839 val_bpb:1.2934 eval_time:14683ms +final_int8_zlib_roundtrip_exact val_loss:2.18385094 val_bpb:1.29339954 +Starting PPM byte mixture evaluation... +ppm_mix_progress seq:500/60568 tokens:513024 bytes:1223626 contexts:330630 skipped_ctx:0 mix_bpb:0.772718 +ppm_mix_progress seq:1000/60568 tokens:1025024 bytes:2455049 contexts:498547 skipped_ctx:0 mix_bpb:0.791232 +ppm_mix_progress seq:1500/60568 tokens:1537024 bytes:3681252 contexts:626070 skipped_ctx:0 mix_bpb:0.798705 +ppm_mix_progress seq:2000/60568 tokens:2049024 bytes:4901593 contexts:743558 skipped_ctx:0 mix_bpb:0.806348 +ppm_mix_progress seq:2500/60568 tokens:2561024 bytes:6124160 contexts:844658 skipped_ctx:0 mix_bpb:0.809410 +ppm_mix_progress seq:3000/60568 tokens:3073024 bytes:7342115 contexts:940062 skipped_ctx:0 mix_bpb:0.812034 +ppm_mix_progress seq:3500/60568 tokens:3585024 bytes:8556619 contexts:1034335 skipped_ctx:0 mix_bpb:0.815524 +ppm_mix_progress seq:4000/60568 tokens:4097024 bytes:9782949 contexts:1114667 skipped_ctx:0 mix_bpb:0.817923 +ppm_mix_progress seq:4500/60568 tokens:4609024 bytes:11023557 contexts:1185904 skipped_ctx:0 mix_bpb:0.820011 +ppm_mix_progress seq:5000/60568 tokens:5121024 bytes:12250013 contexts:1260459 skipped_ctx:0 mix_bpb:0.821949 +ppm_mix_progress seq:5500/60568 tokens:5633024 bytes:13478107 contexts:1335757 skipped_ctx:0 mix_bpb:0.823841 +ppm_mix_progress seq:6000/60568 tokens:6145024 bytes:14694159 contexts:1408314 skipped_ctx:0 mix_bpb:0.825123 +ppm_mix_progress seq:6500/60568 tokens:6657024 bytes:15917346 contexts:1482153 skipped_ctx:0 mix_bpb:0.826809 +ppm_mix_progress seq:7000/60568 tokens:7169024 bytes:17147554 contexts:1545185 skipped_ctx:0 mix_bpb:0.827852 +ppm_mix_progress seq:7500/60568 tokens:7681024 bytes:18359660 contexts:1612554 skipped_ctx:0 mix_bpb:0.828923 +ppm_mix_bpb:0.829467 diff --git a/records/track_10min_16mb/2026-04-30_SP1024_ValueResid_PPMMix/README.md b/records/track_10min_16mb/2026-04-30_SP1024_ValueResid_PPMMix/README.md new file mode 100644 index 0000000000..c525343316 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_SP1024_ValueResid_PPMMix/README.md @@ -0,0 +1,132 @@ +# SP1024 + Value Residual + Byte-Level PPM Mixture + +## Overview + +This submission is the result of an incremental research process rather than a single clean-sheet design. + +The training script was built step by step across many rounds of experiments. Instead of hard-coding one fixed model, we kept most architecture, optimization, tokenizer, and evaluation ideas behind environment-controlled switches so we could run controlled ablations quickly and compare many alternatives within one stable framework. + +The final submission in this folder is a **record 16MB submission** based on: + +- SentencePiece 1024 tokenizer +- 9-layer Transformer +- model dimension 512 +- 8 attention heads / 4 KV heads +- MLP multiplier 2 +- Value Residual enabled in the last 2 layers +- byte-level PPM mixture during final evaluation + +## Submission Type + +This is a **record submission**. + +The included best run was produced on **1×H100**, with 600s as the wall clock. 
We do **not** claim verified compliance with the official **8×H100 / 10-minute** leaderboard requirement in this folder. + +However, this run does satisfy the artifact-size requirement: + +- compressed model: `15,650,103 bytes` +- code size: `156,032 bytes` +- total submission size: `15,806,135 bytes` + +This fits under the 16MB limit. + +## Best Included Result + +### Neural roundtrip score +- `final_int8_zlib_roundtrip_exact val_bpb = 1.29339954` + +### Final mixed score +- `ppm_mix_bpb = 0.829467` + +This was the strongest included result for the SP1024 compact line. + +## Main Idea + +Our final direction is intentionally simple: + +1. keep a compact Transformer backbone +2. improve the late value path with **Value Residual** +3. combine the neural model with a **byte-level PPM mixture** at evaluation time + +In our experiments, this combination was more useful than continuing to add more complicated architectural branches. + +## How the Code Evolved + +This codebase was not written as a minimal one-off competition script. +It evolved as a research scaffold. + +Over time, we added switches for many ideas so that the same script could be reused for many sweeps and fair ablations. The broader script supports experimentation with: + +- tokenizer variants +- BiFPN / BiFPN2 skip fusion +- XSA +- N-gram augmentation +- Value Residual +- cross-layer V and KV sharing +- PLE +- MTP +- parallel residual variants +- parallel-v2 side lanes +- LoRA-TTT +- byte-level PPM mixture + +Many of these ideas were explored, but the strongest compact SP1024 line for this submission ended up being: + +**compact backbone + value residual + byte-level mixture** + +## Experimental Summary + +A short summary of the findings that most influenced this submission: + +### 1. Tokenizer choice mattered +Earlier sweeps showed that tokenizer choice had a large impact on compression performance. We explored SP1024, SP4096, and SP8192. For this submission, we chose SP1024 because it provided a compact, size-friendly line suitable for a 16MB submission. + +### 2. Capacity still mattered +Increasing backbone capacity often helped, but for this submission we prioritized a compact model that still achieved a strong mixed score while fitting under the 16MB limit. + +### 3. Value Residual was the strongest late-layer architectural improvement +Across many later Transformer ablations, **Value Residual** was the most consistent improvement that survived repeated testing. In this submission we enable it only in the last 2 layers. + +### 4. Byte-level PPM mixture produced the largest final gain +The final score improvement came primarily from combining the neural model with a **byte-level PPM mixture** rather than from continuing to add more neural-only complexity. 
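+To make the last point concrete, here is a minimal sketch of one plausible form of the confidence-gated byte-level mixture, consistent with the `LAMBDA_LO` / `LAMBDA_HI` / `PPM_CONF_THRESHOLD` settings listed in the next section: when the PPM model's top next-byte probability clears the confidence threshold, the PPM distribution receives the larger weight, otherwise the smaller one. The function name and array layout here are illustrative assumptions, not the script's actual API.
+
+```python
+import numpy as np
+
+def mix_bits_per_byte(nn_probs, ppm_probs, targets,
+                      lambda_lo=0.10, lambda_hi=0.75, conf_threshold=0.9):
+    """Confidence-gated mixture of neural and PPM next-byte distributions.
+
+    nn_probs, ppm_probs: (N, 256) per-position next-byte probabilities.
+    targets: (N,) observed next bytes.
+    """
+    conf = ppm_probs.max(axis=1)                        # PPM confidence per position
+    lam = np.where(conf >= conf_threshold, lambda_hi, lambda_lo)
+    mixed = lam[:, None] * ppm_probs + (1.0 - lam)[:, None] * nn_probs
+    p_target = mixed[np.arange(len(targets)), targets]
+    return float(-np.log2(np.clip(p_target, 1e-12, None)).mean())
+```
+
+Under this reading, `LAMBDA_HI=0.75` corresponds to the `ppm_hi075` tag in the included run name.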
+ +## Final Configuration + +Key settings for the included run: + +- `VOCAB_SIZE=1024` +- `NUM_LAYERS=9` +- `MODEL_DIM=512` +- `NUM_HEADS=8` +- `NUM_KV_HEADS=4` +- `MLP_MULT=2` +- `VALUE_RESIDUAL_ENABLED=1` +- `VALUE_RESIDUAL_LAST_N_LAYERS=2` +- `BIFPN2_MODE=1` +- `XSA_ENABLED=1` +- `NGRAM_MAX_N=2` +- `EMA_ENABLED=1` +- `LATE_QAT_RATIO=0.15` +- `PPM_ENABLED=1` +- `PPM_ORDER=5` +- `PPM_CONF_THRESHOLD=0.9` +- `LAMBDA_LO=0.10` +- `LAMBDA_HI=0.75` + +## Included Files + +This folder contains: + +- `train_gpt.py` — final training and evaluation script +- `submission.json` — submission metadata +- `config.json` — selected configuration for the included run +- `requirements.txt` — Python dependencies +- `train.log` — log from the included best run +- `seed_runs.csv` — representative run summary + +## Reproduction + +A representative launch for the included run is equivalent to: + +```bash +torchrun --nproc_per_node=1 train_gpt.py diff --git a/records/track_10min_16mb/2026-04-30_SP1024_ValueResid_PPMMix/config.json b/records/track_10min_16mb/2026-04-30_SP1024_ValueResid_PPMMix/config.json new file mode 100644 index 0000000000..7c00105dbd --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_SP1024_ValueResid_PPMMix/config.json @@ -0,0 +1,110 @@ +{ + "DATA_PATH": "./data/datasets/fineweb10B_sp1024", + "TOKENIZER_PATH": "./data/tokenizers/fineweb_1024_bpe.model", + "NUM_LAYERS": 9, + "MODEL_DIM": 512, + "NUM_HEADS": 8, + "NUM_KV_HEADS": 4, + "MLP_MULT": 2, + "VOCAB_SIZE": 1024, + "TIE_EMBEDDINGS": 1, + "ROPE_BASE": 10000.0, + "ROPE_DIMS": -1, + "LEARNABLE_ROPE": 0, + "LOGIT_SOFTCAP": 30.0, + "QK_GAIN_INIT": 4.0, + "GRAD_ACCUM_STEPS": 4, + "TRAIN_BATCH_TOKENS": 524288, + "TRAIN_SEQ_LEN": 1024, + "ITERATIONS": 20000, + "WARMUP_STEPS": 20, + "WARMDOWN_ITERS": 900, + "MAX_WALLCLOCK_SECONDS": 600.0, + "VAL_BATCH_SIZE": 524288, + "VAL_LOSS_EVERY": 1000, + "TRAIN_LOG_EVERY": 200, + "MATRIX_LR": 0.04, + "SCALAR_LR": 0.04, + "EMBED_LR": 0.6, + "HEAD_LR": 0.008, + "TIED_EMBED_LR": 0.05, + "TIED_EMBED_INIT_STD": 0.005, + "MUON_MOMENTUM": 0.95, + "MUON_BACKEND_STEPS": 5, + "MUON_MOMENTUM_WARMUP_START": 0.85, + "MUON_MOMENTUM_WARMUP_STEPS": 500, + "BETA1": 0.9, + "BETA2": 0.95, + "ADAM_EPS": 1e-08, + "GRAD_CLIP_NORM": 0.0, + "FDA_MODE": 0, + "BIFPN_MODE": 0, + "BIFPN2_MODE": 1, + "BIFPN_GROUP_COUNT": 8, + "BIFPN_BAND_WIDTH": 1, + "BIFPN_NORM_EPS": 0.0001, + "BIFPN_INIT_MAIN": 1.0, + "BIFPN_INIT_NEIGHBOR": 0.15, + "BIFPN_INIT_FAR": 0.0, + "SMEAR_MODE": 0, + "SMEAR_WINDOW": 4, + "SMEAR_GATE": 0, + "LN_SCALE": 1, + "LEARNABLE_LN_SCALE": 0, + "AFFINE_NORM": 0, + "SCALEDLM_HEAD": 1, + "XSA_ENABLED": 1, + "XSA_LAST_N_LAYERS": 4, + "XSA_EPS": 1e-06, + "V_SKIP_ENABLED": 0, + "V_SKIP_LAST_N_LAYERS": 0, + "V_SKIP_MODE": "scalar", + "V_SKIP_GROUP_COUNT": 8, + "CROSS_LAYER_V_ENABLED": 0, + "CROSS_LAYER_V_LAST_N_LAYERS": 4, + "CROSS_LAYER_V_MODE": "residual", + "CROSS_LAYER_V_GROUP_COUNT": 4, + "CROSS_LAYER_KV_SHARING_ENABLED": 0, + "CROSS_LAYER_KV_LAST_N_LAYERS": 0, + "CROSS_LAYER_KV_SHARE_K": 1, + "CROSS_LAYER_KV_SHARE_V": 1, + "CROSS_LAYER_KV_PAIRWISE": 0, + "CROSS_LAYER_KV_PARTIAL_HEAD": 0, + "CROSS_LAYER_KV_PARTIAL_HEAD_COUNT": 2, + "VALUE_RESIDUAL_ENABLED": 1, + "VALUE_RESIDUAL_LAST_N_LAYERS": 2, + "VALUE_RESIDUAL_INIT_V0": 0.5, + "VALUE_RESIDUAL_INIT_CUR": 0.5, + "PLE_ENABLED": 0, + "MTP_NUM_HEADS": 0, + "NGRAM_VOCAB_SIZE": 2048, + "NGRAM_DIM": 128, + "NGRAM_MAX_N": 2, + "NGRAM_FADE_ENABLE": 1, + "NGRAM_FADE_START_FRAC": 0.15, + "NGRAM_FADE_END_FRAC": 0.45, + "NGRAM_FADE_MIN_SCALE": 0.0, + "EMA_ENABLED": 1, + 
"EMA_DECAY": 0.997, + "LATE_QAT_RATIO": 0.15, + "DYNAMIC_CLIP_PERCENTILES": "100.0,99.9999,99.9995,99.995,99.99,99.95,99.9,99.8", + "EVAL_USE_SLIDING_WINDOW": 0, + "EVAL_STRIDE": 1024, + "EVAL_BATCH_SEQS": 16, + "TELEMETRY_EVERY": 50, + "PROFILE_RUN": 0, + "PROFILE_WARMUP_STEPS": 5, + "PROFILE_ACTIVE_STEPS": 10, + "TTT_ENABLED": 0, + "LORA_TTT_ENABLED": 0, + "PPM_ENABLED": 1, + "PPM_ORDER": 5, + "PPM_SUBSET_TOKENS": 8000000, + "PPM_CONF_THRESHOLD": 0.9, + "LAMBDA_LO": 0.1, + "LAMBDA_HI": 0.75, + "NN_BYTE_PROJECTION": "spread_root", + "NN_BYTE_UNIFORM_FLOOR": 1e-06, + "STOP_MODE": "steps", + "MAX_TRAIN_STEPS": 2200 +} \ No newline at end of file diff --git a/records/track_10min_16mb/2026-04-30_SP1024_ValueResid_PPMMix/requirements.txt b/records/track_10min_16mb/2026-04-30_SP1024_ValueResid_PPMMix/requirements.txt new file mode 100644 index 0000000000..911b0e52f0 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_SP1024_ValueResid_PPMMix/requirements.txt @@ -0,0 +1,10 @@ +numpy +tqdm +torch +huggingface-hub +kernels +setuptools +typing-extensions==4.15.0 +datasets +tiktoken +sentencepiece \ No newline at end of file diff --git a/records/track_10min_16mb/2026-04-30_SP1024_ValueResid_PPMMix/seed_runs.csv b/records/track_10min_16mb/2026-04-30_SP1024_ValueResid_PPMMix/seed_runs.csv new file mode 100644 index 0000000000..57a683fd69 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_SP1024_ValueResid_PPMMix/seed_runs.csv @@ -0,0 +1,2 @@ +run_name,seed,tokenizer,model_dim,mlp_mult,num_layers,num_heads,num_kv_heads,value_residual_last_n_layers,ppm_order,ppm_conf_threshold,lambda_lo,lambda_hi,max_train_steps,world_size,grad_accum_steps,last_val_bpb,roundtrip_exact_val_bpb,ppm_mix_bpb,compressed_model_bytes,code_bytes,total_submission_bytes,fits_16mb,notes +record_try_sp1024_512d_mlp2_value_resid_last2_ppm_hi075_steps2200,1337,sp1024,512,2,9,8,4,2,5,0.9,0.10,0.75,2200,1,4,1.2894,1.29339954,0.829467,15650103,156032,15806135,true,"Best included single-H100 non-record run; artifact fits under 16MB" diff --git a/records/track_10min_16mb/2026-04-30_SP1024_ValueResid_PPMMix/train_gpt.py b/records/track_10min_16mb/2026-04-30_SP1024_ValueResid_PPMMix/train_gpt.py new file mode 100644 index 0000000000..2cde79e519 --- /dev/null +++ b/records/track_10min_16mb/2026-04-30_SP1024_ValueResid_PPMMix/train_gpt.py @@ -0,0 +1,3389 @@ +from __future__ import annotations + +""" +mytrain_gpt_v5_lora_ttt.py + +V5 mainline focused on the highest-priority path: + 1) Keep the current strong baseline backbone options. + 2) Add legal score-first LoRA-TTT as the main new feature. + 3) Keep the implementation torch.compile-friendly during normal training. + +This file is intended as a direct evolution target from your v4 script. +It contains the new flags, LoRA modules, TTT adapter wiring, and integration +points you can merge into the existing code. + +Design goals: +- Normal train/inference path remains compile-friendly. +- TTT path is isolated and runs after final int8 roundtrip eval. +- LoRA parameters are only used when TTT is enabled. +- Supports warm-start A / reset B, alpha/rank scaling, independent WD. +- Supports score-first legality: score a chunk first, then update on it. 
+""" + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch.profiler +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +import json + + +# ============================================================ +# HYPERPARAMETERS +# ============================================================ + +class Hyperparameters: + # ----------------------------- + # Data + # ----------------------------- + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # ----------------------------- + # Validation / logging + # ----------------------------- + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # ----------------------------- + # Training schedule + # ----------------------------- + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + stop_mode = os.environ.get("STOP_MODE", "walltime") + max_train_steps = int(os.environ.get("MAX_TRAIN_STEPS", "0")) + + # ----------------------------- + # Sliding eval + # ----------------------------- + eval_use_sliding_window = bool(int(os.environ.get("EVAL_USE_SLIDING_WINDOW", "0"))) + eval_stride = int(os.environ.get("EVAL_STRIDE", "128")) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", "16")) + + # ----------------------------- + # Core model + # ----------------------------- + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + rope_dims = int(os.environ.get("ROPE_DIMS", "-1")) + learnable_rope = bool(int(os.environ.get("LEARNABLE_ROPE", "0"))) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # ----------------------------- + # Optimizer + # ----------------------------- + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + 
muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # ----------------------------- + # Telemetry / profile + # ----------------------------- + output_dir = os.environ.get("OUTPUT_DIR", "") + telemetry_every = int(os.environ.get("TELEMETRY_EVERY", "0")) + telemetry_file = os.environ.get( + "TELEMETRY_FILE", + os.path.join(output_dir, "telemetry.jsonl") if output_dir else "logs/telemetry.jsonl", + ) + profile_run = bool(int(os.environ.get("PROFILE_RUN", "0"))) + profile_warmup_steps = int(os.environ.get("PROFILE_WARMUP_STEPS", "5")) + profile_active_steps = int(os.environ.get("PROFILE_ACTIVE_STEPS", "10")) + profile_step = int(os.environ.get("PROFILE_STEP", "-1")) + profile_output_dir = os.environ.get("PROFILE_OUTPUT_DIR", "output/prof_base") + + # ----------------------------- + # Baseline architecture flags + # ----------------------------- + fda_mode = bool(int(os.environ.get("FDA_MODE", "0"))) + bifpn_mode = bool(int(os.environ.get("BIFPN_MODE", "0"))) + bifpn2_mode = bool(int(os.environ.get("BIFPN2_MODE", "0"))) + bifpn_group_count = int(os.environ.get("BIFPN_GROUP_COUNT", "8")) + bifpn_band_width = int(os.environ.get("BIFPN_BAND_WIDTH", "1")) + bifpn_norm_eps = float(os.environ.get("BIFPN_NORM_EPS", "1e-4")) + bifpn_init_main = float(os.environ.get("BIFPN_INIT_MAIN", "1.0")) + bifpn_init_neighbor = float(os.environ.get("BIFPN_INIT_NEIGHBOR", "0.15")) + bifpn_init_far = float(os.environ.get("BIFPN_INIT_FAR", "0.0")) + + smear_mode = bool(int(os.environ.get("SMEAR_MODE", "0"))) + smear_window = int(os.environ.get("SMEAR_WINDOW", "4")) + smear_gate = bool(int(os.environ.get("SMEAR_GATE", "0"))) + + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + learnable_ln_scale = bool(int(os.environ.get("LEARNABLE_LN_SCALE", "0"))) + affine_norm = bool(int(os.environ.get("AFFINE_NORM", "0"))) + scaledlm_head = bool(int(os.environ.get("SCALEDLM_HEAD", "1"))) + + xsa_enabled = bool(int(os.environ.get("XSA_ENABLED", "0"))) + xsa_last_n_layers = int(os.environ.get("XSA_LAST_N_LAYERS", "0")) + xsa_eps = float(os.environ.get("XSA_EPS", "1e-6")) + + v_skip_enabled = bool(int(os.environ.get("V_SKIP_ENABLED", "0"))) + v_skip_last_n_layers = int(os.environ.get("V_SKIP_LAST_N_LAYERS", "0")) + v_skip_mode = os.environ.get("V_SKIP_MODE", "scalar") + v_skip_group_count = int(os.environ.get("V_SKIP_GROUP_COUNT", "8")) + + cross_layer_v_enabled = bool(int(os.environ.get("CROSS_LAYER_V_ENABLED", "0"))) + cross_layer_v_last_n_layers = int(os.environ.get("CROSS_LAYER_V_LAST_N_LAYERS", "0")) + cross_layer_v_mode = os.environ.get("CROSS_LAYER_V_MODE", "residual") + cross_layer_v_group_count = int(os.environ.get("CROSS_LAYER_V_GROUP_COUNT", "8")) + + cross_layer_kv_sharing_enabled = bool(int(os.environ.get("CROSS_LAYER_KV_SHARING_ENABLED", "0"))) + cross_layer_kv_last_n_layers = int(os.environ.get("CROSS_LAYER_KV_LAST_N_LAYERS", "0")) + cross_layer_kv_share_k = bool(int(os.environ.get("CROSS_LAYER_KV_SHARE_K", "1"))) + cross_layer_kv_share_v = bool(int(os.environ.get("CROSS_LAYER_KV_SHARE_V", "1"))) + cross_layer_kv_pairwise = bool(int(os.environ.get("CROSS_LAYER_KV_PAIRWISE", "0"))) + cross_layer_kv_partial_head = 
bool(int(os.environ.get("CROSS_LAYER_KV_PARTIAL_HEAD", "0"))) + cross_layer_kv_partial_head_count = int(os.environ.get("CROSS_LAYER_KV_PARTIAL_HEAD_COUNT", "2")) + cross_layer_kv_source_mode = os.environ.get("CROSS_LAYER_KV_SOURCE_MODE", "previous") + + # ----------------------------- + # Depth recurrence / value residual + # ----------------------------- + depth_recur_enabled = bool(int(os.environ.get("DEPTH_RECUR_ENABLED", "0"))) + num_stem_blocks = int(os.environ.get("NUM_STEM_BLOCKS", "3")) + num_core_blocks = int(os.environ.get("NUM_CORE_BLOCKS", "3")) + num_core_repeats = int(os.environ.get("NUM_CORE_REPEATS", "3")) + + value_residual_enabled = bool(int(os.environ.get("VALUE_RESIDUAL_ENABLED", "0"))) + value_residual_last_n_layers = int(os.environ.get("VALUE_RESIDUAL_LAST_N_LAYERS", "0")) + value_residual_init_v0 = float(os.environ.get("VALUE_RESIDUAL_INIT_V0", "0.5")) + value_residual_init_cur = float(os.environ.get("VALUE_RESIDUAL_INIT_CUR", "0.5")) + + # ----------------------------- + # PLE + # ----------------------------- + ple_enabled = bool(int(os.environ.get("PLE_ENABLED", "0"))) + ple_temporal_conv = bool(int(os.environ.get("PLE_TEMPORAL_CONV", "0"))) + ple_dim = int(os.environ.get("PLE_DIM", "32")) + ple_mode = os.environ.get("PLE_MODE", "post_attn") + ple_token_scale_init = float(os.environ.get("PLE_TOKEN_SCALE_INIT", "1.0")) + ple_ctx_scale_init = float(os.environ.get("PLE_CTX_SCALE_INIT", "1.0")) + ple_resid_scale_init = float(os.environ.get("PLE_RESID_SCALE_INIT", "0.1")) + + # ----------------------------- + # MTP + # ----------------------------- + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", "0")) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", "0.2")) + mtphead_mlpmode = bool(int(os.environ.get("MTPHEAD_MLPMODE", "0"))) + + # ----------------------------- + # N-gram + # ----------------------------- + ngram_vocab_size = int(os.environ.get("NGRAM_VOCAB_SIZE", "2048")) + ngram_dim = int(os.environ.get("NGRAM_DIM", "128")) + ngram_max_n = int(os.environ.get("NGRAM_MAX_N", "4")) + ngram_fade_enable = bool(int(os.environ.get("NGRAM_FADE_ENABLE", "0"))) + ngram_fade_start_frac = float(os.environ.get("NGRAM_FADE_START_FRAC", "0.15")) + ngram_fade_end_frac = float(os.environ.get("NGRAM_FADE_END_FRAC", "0.45")) + ngram_fade_min_scale = float(os.environ.get("NGRAM_FADE_MIN_SCALE", "0.0")) + + # ----------------------------- + # EMA / QAT + # ----------------------------- + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) + ema_decay = float(os.environ.get("EMA_DECAY", "0.997")) + late_qat_ratio = float(os.environ.get("LATE_QAT_RATIO", "0.15")) + dynamic_clip_percentiles = tuple( + float(x.strip()) + for x in os.environ.get( + "DYNAMIC_CLIP_PERCENTILES", + "100.0,99.9999,99.9995,99.995,99.99,99.95,99.9,99.8", + ).split(",") + if x.strip() + ) + + # ----------------------------- + # V5 MAIN FEATURE: LoRA-TTT + # ----------------------------- + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_mode = os.environ.get("TTT_MODE", "lora") # lora | full + ttt_lr = float(os.environ.get("TTT_LR", "0.002")) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", "1")) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", "49152")) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", "32")) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", "1.0")) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", "0.9")) + ttt_weight_decay = float(os.environ.get("TTT_WEIGHT_DECAY", "1.0")) + ttt_freeze_blocks = 
int(os.environ.get("TTT_FREEZE_BLOCKS", "0")) + + lora_ttt_enabled = bool(int(os.environ.get("LORA_TTT_ENABLED", "0"))) + lora_ttt_rank = int(os.environ.get("LORA_TTT_RANK", "128")) + lora_ttt_alpha = float(os.environ.get("LORA_TTT_ALPHA", "144.0")) + lora_ttt_dropout = float(os.environ.get("LORA_TTT_DROPOUT", "0.0")) + lora_ttt_warm_start_a = bool(int(os.environ.get("LORA_TTT_WARM_START_A", "1"))) + lora_ttt_reset_b_each_chunk = bool(int(os.environ.get("LORA_TTT_RESET_B_EACH_CHUNK", "1"))) + lora_ttt_targets = os.environ.get( + "LORA_TTT_TARGETS", + "attn_q,attn_k,attn_v,attn_proj,mlp_fc,mlp_proj", + ) + + parallel_residual_enabled = bool(int(os.environ.get("PARALLEL_RESIDUAL_ENABLED", "0"))) + parallel_residual_last_n_layers = int(os.environ.get("PARALLEL_RESIDUAL_LAST_N_LAYERS", "0")) + parallel_residual_mode = os.environ.get("PARALLEL_RESIDUAL_MODE", "dual_add") # dual_add | gated_add + parallel_residual_init_attn = float(os.environ.get("PARALLEL_RESIDUAL_INIT_ATTN", "1.0")) + parallel_residual_init_mlp = float(os.environ.get("PARALLEL_RESIDUAL_INIT_MLP", "1.0")) + parallel_residual_gate_init = float(os.environ.get("PARALLEL_RESIDUAL_GATE_INIT", "0.0")) + + # ----------------------------- + # Parallel Residual v2 / hybrid second lane + # ----------------------------- + parallel_v2_enabled = bool(int(os.environ.get("PARALLEL_V2_ENABLED", "0"))) + parallel_v2_last_n_layers = int(os.environ.get("PARALLEL_V2_LAST_N_LAYERS", "2")) + parallel_v2_mode = os.environ.get("PARALLEL_V2_MODE", "dual_add") # dual_add | gated_add | delayed_merge + parallel_v2_second_lane = os.environ.get("PARALLEL_V2_SECOND_LANE", "mlp") # mlp | gated_linear | conv_gate | ssm + parallel_v2_init_attn = float(os.environ.get("PARALLEL_V2_INIT_ATTN", "1.0")) + parallel_v2_init_second = float(os.environ.get("PARALLEL_V2_INIT_SECOND", "1.0")) + parallel_v2_gate_init = float(os.environ.get("PARALLEL_V2_GATE_INIT", "0.0")) + parallel_v2_delayed_merge_steps = int(os.environ.get("PARALLEL_V2_DELAYED_MERGE_STEPS", "1")) + parallel_v2_norm_shared = bool(int(os.environ.get("PARALLEL_V2_NORM_SHARED", "1"))) + parallel_v2_use_post_attn_ple = bool(int(os.environ.get("PARALLEL_V2_USE_POST_ATTN_PLE", "0"))) + parallel_v2_log_norm_ratios = bool(int(os.environ.get("PARALLEL_V2_LOG_NORM_RATIOS", "0"))) + + gated_linear_mult = int(os.environ.get("GATED_LINEAR_MULT", "2")) + gated_linear_zero_init = bool(int(os.environ.get("GATED_LINEAR_ZERO_INIT", "1"))) + conv_gate_kernel_size = int(os.environ.get("CONV_GATE_KERNEL_SIZE", "4")) + ssm_state_dim = int(os.environ.get("SSM_STATE_DIM", "8")) + ssm_expand = int(os.environ.get("SSM_EXPAND", "2")) + ssm_conv_kernel = int(os.environ.get("SSM_CONV_KERNEL", "8")) + ssm_gate = bool(int(os.environ.get("SSM_GATE", "1"))) + + # ----------------------------- + # Lossless / PPM mixture eval + # ----------------------------- + ppm_enabled = bool(int(os.environ.get("PPM_ENABLED", "0"))) + ppm_order = int(os.environ.get("PPM_ORDER", "5")) + ppm_subset_tokens = int(os.environ.get("PPM_SUBSET_TOKENS", "0")) # 0 = full val + ppm_conf_threshold = float(os.environ.get("PPM_CONF_THRESHOLD", "0.9")) + lambda_lo = float(os.environ.get("LAMBDA_LO", "0.05")) + lambda_hi = float(os.environ.get("LAMBDA_HI", "0.9")) + ppm_max_contexts = int(os.environ.get("PPM_MAX_CONTEXTS", "0")) # 0 = unbounded + + # token->byte projection + nn_byte_projection = os.environ.get("NN_BYTE_PROJECTION", "spread_root") + nn_byte_uniform_floor = float(os.environ.get("NN_BYTE_UNIFORM_FLOOR", "1e-6")) + + # ----------------------------- + # 
Eval-only mode + # ----------------------------- + eval_only = bool(int(os.environ.get("EVAL_ONLY", "0"))) + checkpoint = os.environ.get("CHECKPOINT", "") + + +# ============================================================ +# COMPILE-FRIENDLY HELPERS +# ============================================================ + +@torch.compile(dynamic=False, fullgraph=True) +def update_ema_fused(ema_tensors: list[Tensor], model_tensors: list[Tensor], decay: float): + for e, m in zip(ema_tensors, model_tensors): + e.mul_(decay).add_(m.float(), alpha=1.0 - decay) + + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr: curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr: curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ============================================================ +# TOKENIZER-AGNOSTIC EVAL HELPERS +# ============================================================ + +def build_sentencepiece_luts(sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device): + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + 
torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def build_sentencepiece_byte_tables( + sp: spm.SentencePieceProcessor, + vocab_size: int, +): + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + + token_piece_bytes: list[bytes] = [b""] * table_size + has_leading_space = [False] * table_size + is_boundary_token = [True] * table_size + + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + token_piece_bytes[token_id] = b"" + is_boundary_token[token_id] = True + continue + + is_boundary_token[token_id] = False + + if sp.is_byte(token_id): + piece = sp.id_to_piece(token_id) + # piece like <0xAB> + if piece.startswith("<0x") and piece.endswith(">") and len(piece) == 6: + token_piece_bytes[token_id] = bytes([int(piece[3:5], 16)]) + else: + token_piece_bytes[token_id] = b"" + continue + + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space[token_id] = True + piece = piece[1:] + token_piece_bytes[token_id] = piece.encode("utf-8") + + return token_piece_bytes, has_leading_space, is_boundary_token + + +def reconstruct_token_bytes( + prev_token_id: int, + token_id: int, + token_piece_bytes: list[bytes], + has_leading_space: list[bool], + is_boundary_token: list[bool], +) -> bytes: + base = token_piece_bytes[token_id] + if not base: + return b"" + if has_leading_space[token_id] and not is_boundary_token[prev_token_id]: + return b" " + base + return base + + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def tokens_to_bytes_count(xb: Tensor, yb: Tensor, base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor) -> Tensor: + prev_ids = xb.reshape(-1) + tgt_ids = yb.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + return token_bytes.sum() + + +# ============================================================ +# LoRA-TTT MODULES +# ============================================================ + +class BatchedLinearLoRA(nn.Module): + """ + Compile-friendly normal path: + y = base(x) + TTT path: + y = base(x) + scale * B(A(dropout(x))) + + Important details copied from the strong public PR direction: + - scale = alpha / rank + - A can be warm-started across chunks + - B can be reset each chunk + """ + def __init__(self, base: nn.Module, rank: int, alpha: float, dropout: float = 0.0): + super().__init__() + if not isinstance(base, (nn.Linear, CastedLinear)): + raise TypeError(f"Unsupported base module: {type(base)}") + self.base = base + self.in_features = base.in_features + self.out_features = base.out_features + self.rank = rank + self.alpha = alpha + self.scale = alpha / max(rank, 1) + self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + + self.lora_A = nn.Parameter(torch.empty(rank, self.in_features, 
dtype=torch.float32)) + self.lora_B = nn.Parameter(torch.empty(self.out_features, rank, dtype=torch.float32)) + self.lora_enabled = False + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + @torch.no_grad() + def reset_B_only(self): + nn.init.zeros_(self.lora_B) + + def forward(self, x: Tensor) -> Tensor: + y = self.base(x) + if not self.lora_enabled: + return y + x_d = self.dropout(x) + a = F.linear(x_d, self.lora_A.to(dtype=x.dtype), bias=None) + b = F.linear(a, self.lora_B.to(dtype=x.dtype), bias=None) + return y + self.scale * b + + +class LoRATTTManager: + def __init__(self, model: nn.Module, args: Hyperparameters): + self.model = model + self.args = args + self.targets = {s.strip() for s in args.lora_ttt_targets.split(",") if s.strip()} + self.adapters: dict[str, BatchedLinearLoRA] = {} + + def _want_module(self, name: str) -> bool: + mapping = { + "attn_q": ".attn.c_q", + "attn_k": ".attn.c_k", + "attn_v": ".attn.c_v", + "attn_proj": ".attn.proj", + "mlp_fc": ".mlp.fc", + "mlp_proj": ".mlp.proj", + } + for k, suffix in mapping.items(): + if k in self.targets and name.endswith(suffix): + return True + return False + + def inject(self): + replacements = [] + for name, module in self.model.named_modules(): + if self._want_module(name) and isinstance(module, (nn.Linear, CastedLinear)): + replacements.append((name, module)) + + for full_name, module in replacements: + parent_name, child_name = full_name.rsplit('.', 1) + parent = self.model.get_submodule(parent_name) + wrapped = BatchedLinearLoRA( + base=module, + rank=self.args.lora_ttt_rank, + alpha=self.args.lora_ttt_alpha, + dropout=self.args.lora_ttt_dropout, + ) + setattr(parent, child_name, wrapped) + self.adapters[full_name] = wrapped + + def set_enabled(self, enabled: bool): + for mod in self.adapters.values(): + mod.lora_enabled = enabled + + def lora_parameters(self): + for mod in self.adapters.values(): + yield mod.lora_A + yield mod.lora_B + + @torch.no_grad() + def reset_chunk_state(self): + for mod in self.adapters.values(): + if self.args.lora_ttt_reset_b_each_chunk: + mod.reset_B_only() + if not self.args.lora_ttt_warm_start_a: + nn.init.kaiming_uniform_(mod.lora_A, a=math.sqrt(5)) + + +# ============================================================ +# MODEL BUILDING BLOCKS +# ============================================================ + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,vr_lambda,parallel_attn_scale,parallel_mlp_scale,parallel_gate,parallel_v2_attn_scale,parallel_v2_second_scale,parallel_v2_gate,ssm_A_log,ssm_B,ssm_C", + ).split(",") if p +) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int | None = None, eps: float | None = None, affine: bool = False): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) if (affine and dim is not None) else None + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) if self.weight is not None else None + return F.rms_norm(x, (x.size(-1),), weight=w, eps=self.eps) + + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, bias: bool = False): + super().__init__(in_features, out_features, bias=bias) + self.register_buffer("qat_alpha", torch.tensor(0.0, dtype=torch.float32), persistent=False) + + def forward(self, x: torch.Tensor) -> 
torch.Tensor: + w = self.weight + w_max = w.detach().abs().amax(dim=1, keepdim=True) + scale = (w_max / 127.0).clamp_min(1e-7) + w_quant = torch.clamp(torch.round(w / scale), -127, 127) * scale + w = w + (self.qat_alpha * (w_quant - w)).detach() + return F.linear(x, w.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype): + if self._cos_cached is None or self._sin_cached is None or self._seq_len_cached != seq_len or self._cos_cached.device != device: + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + half = rope_dims // 2 + x1 = x[..., :half] + x2 = x[..., half:rope_dims] + x_pass = x[..., rope_dims:] + cos_part = cos[..., :half] + sin_part = sin[..., :half] + return torch.cat(( + x1 * cos_part + x2 * sin_part, + x1 * (-sin_part) + x2 * cos_part, + x_pass, + ), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +def _expand_group_gates(g: Tensor, total_dim: int) -> Tensor: + if total_dim % g.numel() != 0: + raise ValueError(f"total_dim ({total_dim}) must be divisible by num_groups ({g.numel()})") + group_dim = total_dim // g.numel() + return g.repeat_interleave(group_dim) + + +def apply_v_skip(y: Tensor, v: Tensor, gate: Tensor, mode: str = "scalar", num_heads: int | None = None, num_kv_heads: int | None = None) -> Tensor: + b, h, t, d = y.shape + hkv = v.shape[1] + if h == hkv: + v_exp = v + else: + group_size = num_heads // num_kv_heads + v_exp = v.unsqueeze(2).expand(b, hkv, group_size, t, d).reshape(b, h, t, d) + if mode == "scalar": + g = torch.sigmoid(gate.to(dtype=y.dtype)).reshape(1, 1, 1, 1) + return y + g * v_exp + if mode == "group": + g = torch.sigmoid(gate.to(dtype=y.dtype)) + g = _expand_group_gates(g, d).view(1, 1, 1, d) + return y + g * v_exp + raise ValueError(f"Unknown V_SKIP_MODE: {mode}") + + +def mix_cross_layer_v(v_cur: Tensor, v_prev: Tensor, gate: Tensor, mode: str = "residual", group_mode: str = "scalar") -> Tensor: + d = v_cur.shape[-1] + if group_mode == "scalar": + g = torch.sigmoid(gate.to(dtype=v_cur.dtype)).reshape(1, 1, 1, 1) + elif group_mode == "group": + g = torch.sigmoid(gate.to(dtype=v_cur.dtype)) + g = _expand_group_gates(g, d).view(1, 1, 1, d) + else: + raise ValueError(f"Unknown cross-layer V group mode: {group_mode}") + if mode == "residual": + return v_cur + g * v_prev + if mode == "blend": + 
return (1.0 - g) * v_cur + g * v_prev + raise ValueError(f"Unknown CROSS_LAYER_V_MODE: {mode}") + + +def apply_partial_head_sharing(cur: Tensor, shared: Tensor, share_head_count: int) -> Tensor: + h = cur.shape[1] + n = min(max(share_head_count, 0), h) + if n == 0: + return cur + out = cur.clone() + out[:, :n] = shared[:, :n] + return out + + +def apply_xsa_gqa_efficient(y: Tensor, v: Tensor, num_heads: int, num_kv_heads: int, eps: float = 1e-6) -> Tensor: + if num_heads == num_kv_heads: + vn = v / (v.norm(dim=-1, keepdim=True) + eps) + proj = (y * vn).sum(dim=-1, keepdim=True) + return y - proj * vn + group_size = num_heads // num_kv_heads + b, h, t, d = y.shape + yg = y.view(b, num_kv_heads, group_size, t, d) + vn = v / (v.norm(dim=-1, keepdim=True) + eps) + vn = vn.unsqueeze(2) + proj = (yg * vn).sum(dim=-1, keepdim=True) + yg = yg - proj * vn + return yg.view(b, h, t, d) + + +class StructuredGroupSignedBiFPN(nn.Module): + def __init__(self, num_decoder_layers, num_encoder_layers, model_dim, group_count=8, band_width=1, norm_eps=1e-4, init_main=1.0, init_neighbor=0.15, init_far=0.0): + super().__init__() + if model_dim % group_count != 0: + raise ValueError(f"model_dim ({model_dim}) must be divisible by group_count ({group_count})") + self.num_decoder_layers = num_decoder_layers + self.num_encoder_layers = num_encoder_layers + self.model_dim = model_dim + self.group_count = group_count + self.group_dim = model_dim // group_count + self.band_width = band_width + self.norm_eps = norm_eps + w = torch.full((num_decoder_layers, num_encoder_layers, group_count), init_far, dtype=torch.float32) + for d in range(num_decoder_layers): + sym = num_encoder_layers - 1 - d + for e in range(num_encoder_layers): + dist_val = abs(e - sym) + if dist_val == 0: + w[d, e, :] = init_main + elif dist_val <= band_width: + w[d, e, :] = init_neighbor + mask = torch.zeros((num_decoder_layers, num_encoder_layers, 1), dtype=torch.float32) + for d in range(num_decoder_layers): + sym = num_encoder_layers - 1 - d + for e in range(num_encoder_layers): + if abs(e - sym) <= band_width: + mask[d, e, 0] = 1.0 + self.weights = nn.Parameter(w) + self.register_buffer("mask", mask, persistent=True) + + def forward(self, skips: list[Tensor], decoder_idx: int, x_dtype: torch.dtype) -> Tensor: + stacked = torch.stack(skips, dim=0) + enc, b, t, d = stacked.shape + stacked_g = stacked.view(enc, b, t, self.group_count, self.group_dim) + w = self.weights[decoder_idx] * self.mask[decoder_idx] + w = w.to(dtype=x_dtype) + denom = w.abs().sum(dim=0, keepdim=True).clamp_min(self.norm_eps) + w_norm = w / denom + fused = torch.einsum("eg,ebtgd->btgd", w_norm, stacked_g) + return fused.reshape(b, t, d) + + +class PLEModule(nn.Module): + def __init__(self, args: Hyperparameters): + super().__init__() + self.enabled = args.ple_enabled + self.temporal_conv = args.ple_temporal_conv + self.dim = args.ple_dim + self.mode = args.ple_mode + self.num_layers = args.num_layers + self.model_dim = args.model_dim + if not self.enabled: + return + self.token_embed = nn.Embedding(args.vocab_size, self.num_layers * self.dim) + self.ctx_proj = nn.Linear(self.model_dim, self.num_layers * self.dim, bias=False) + self.out_proj = nn.Linear(self.dim, self.model_dim, bias=False) + self.token_scale = nn.Parameter(torch.full((1,), args.ple_token_scale_init, dtype=torch.float32)) + self.ctx_scale = nn.Parameter(torch.full((1,), args.ple_ctx_scale_init, dtype=torch.float32)) + self.resid_scale = nn.Parameter(torch.full((1,), args.ple_resid_scale_init, 
dtype=torch.float32)) + if self.temporal_conv: + self.temporal = nn.Conv1d(self.dim, self.dim, kernel_size=3, padding=1, groups=self.dim, bias=False) + else: + self.temporal = None + + def build_all(self, token_ids: Tensor, x_embed: Tensor) -> Tensor | None: + if not self.enabled: + return None + tok = self.token_embed(token_ids).view(token_ids.shape[0], token_ids.shape[1], self.num_layers, self.dim) + ctx = self.ctx_proj(x_embed).view(token_ids.shape[0], token_ids.shape[1], self.num_layers, self.dim) + out = self.token_scale.to(tok.dtype) * tok + self.ctx_scale.to(tok.dtype) * ctx + if self.temporal is not None: + b, t, l, d = out.shape + tmp = out.permute(0, 2, 3, 1).reshape(b * l, d, t) + tmp = self.temporal(tmp) + out = tmp.reshape(b, l, d, t).permute(0, 3, 1, 2) + return out + + def apply(self, x: Tensor, ple_all: Tensor | None, layer_idx: int) -> Tensor: + if ple_all is None: + return x + p = ple_all[:, :, layer_idx, :] + p = self.out_proj(p) + return x + self.resid_scale.to(x.dtype) * p.to(x.dtype) + + +class CausalSelfAttention(nn.Module): + def __init__(self, args: Hyperparameters, layer_idx: int, xsa_enabled: bool = False, xsa_eps: float = 1e-6): + super().__init__() + dim = args.model_dim + num_heads = args.num_heads + num_kv_heads = args.num_kv_heads + rope_base = args.rope_base + qk_gain_init = args.qk_gain_init + + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + + self.layer_idx = layer_idx + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + + self.rope_dims = args.rope_dims if args.rope_dims > 0 else self.head_dim + self.xsa_enabled = xsa_enabled + self.xsa_eps = xsa_eps + self.learnable_rope = args.learnable_rope + + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + if self.learnable_rope: + init_logits = torch.full((self.head_dim // 2,), -4.0, dtype=torch.float32) + init_logits[:8] = 4.0 + self.rope_mix_logits = nn.Parameter(init_logits) + + self.v_skip_enabled = args.v_skip_enabled and (layer_idx >= args.num_layers - args.v_skip_last_n_layers) + self.v_skip_mode = args.v_skip_mode + self.v_skip_group_count = args.v_skip_group_count + if self.v_skip_enabled: + if self.v_skip_mode == "scalar": + self.v_skip_gate = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + elif self.v_skip_mode == "group": + self.v_skip_gate = nn.Parameter(torch.zeros(self.v_skip_group_count, dtype=torch.float32)) + else: + raise ValueError(f"Unknown V_SKIP_MODE: {self.v_skip_mode}") + + self.cross_layer_v_enabled = args.cross_layer_v_enabled and (layer_idx >= args.num_layers - args.cross_layer_v_last_n_layers) + self.cross_layer_v_mode = args.cross_layer_v_mode + self.cross_layer_v_group_count = args.cross_layer_v_group_count + if self.cross_layer_v_enabled: + if args.cross_layer_v_group_count <= 1: + self.cross_layer_v_gate = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + self.cross_layer_v_gate_mode = "scalar" + else: + self.cross_layer_v_gate = 
nn.Parameter(torch.zeros(args.cross_layer_v_group_count, dtype=torch.float32)) + self.cross_layer_v_gate_mode = "group" + + self.cross_layer_kv_sharing_enabled = args.cross_layer_kv_sharing_enabled and (layer_idx >= args.num_layers - args.cross_layer_kv_last_n_layers) + self.cross_layer_kv_share_k = args.cross_layer_kv_share_k + self.cross_layer_kv_share_v = args.cross_layer_kv_share_v + self.cross_layer_kv_pairwise = args.cross_layer_kv_pairwise + self.cross_layer_kv_partial_head = args.cross_layer_kv_partial_head + self.cross_layer_kv_partial_head_count = args.cross_layer_kv_partial_head_count + + self.value_residual_enabled = args.value_residual_enabled and (layer_idx >= args.num_layers - args.value_residual_last_n_layers) + if self.value_residual_enabled: + self.vr_lambda = nn.Parameter(torch.tensor([args.value_residual_init_v0, args.value_residual_init_cur], dtype=torch.float32)) + + def forward(self, x: Tensor, shared_k: Tensor | None = None, shared_v: Tensor | None = None, prev_v: Tensor | None = None, v0: Tensor | None = None): + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + raw_v = v + + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.learnable_rope: + q_rot = apply_rotary_emb(q, cos, sin, rope_dims=0) + k_rot = apply_rotary_emb(k, cos, sin, rope_dims=0) + gamma = torch.sigmoid(self.rope_mix_logits.to(q.dtype)) + gamma = gamma.unsqueeze(-1).expand(-1, 2).reshape(-1) + q = gamma * q_rot + (1 - gamma) * q + k = gamma * k_rot + (1 - gamma) * k + else: + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + + k_eff = k + v_eff = v + + if self.cross_layer_kv_sharing_enabled: + if self.cross_layer_kv_share_k and shared_k is not None: + if self.cross_layer_kv_partial_head: + k_eff = apply_partial_head_sharing(k_eff, shared_k, self.cross_layer_kv_partial_head_count) + else: + k_eff = shared_k + if self.cross_layer_kv_share_v and shared_v is not None: + if self.cross_layer_kv_partial_head: + v_eff = apply_partial_head_sharing(v_eff, shared_v, self.cross_layer_kv_partial_head_count) + else: + v_eff = shared_v + + if self.cross_layer_v_enabled and prev_v is not None: + v_eff = mix_cross_layer_v(v_eff, prev_v, self.cross_layer_v_gate, mode=self.cross_layer_v_mode, group_mode=self.cross_layer_v_gate_mode) + + if self.value_residual_enabled and v0 is not None: + lam = self.vr_lambda.to(dtype=v_eff.dtype) + v_eff = lam[0] * v0 + lam[1] * v_eff + + y = F.scaled_dot_product_attention( + q, k_eff, v_eff, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + + if self.xsa_enabled: + y = apply_xsa_gqa_efficient(y=y, v=v_eff, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, eps=self.xsa_eps) + + if self.v_skip_enabled: + y = apply_v_skip(y=y, v=v_eff, gate=self.v_skip_gate, mode=self.v_skip_mode, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads) + + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + out = self.proj(y) + return out, k_eff, v_eff, raw_v + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = 
CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class GatedLinearLane(nn.Module): + def __init__(self, dim: int, mult: int = 2, zero_init: bool = True): + super().__init__() + self.hidden_dim = max(1, int(mult)) * dim + self.up_proj = CastedLinear(dim, self.hidden_dim, bias=False) + self.gate_proj = CastedLinear(dim, self.hidden_dim, bias=False) + self.down_proj = CastedLinear(self.hidden_dim, dim, bias=False) + if zero_init: + self.down_proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + return self.down_proj(self.up_proj(x) * F.silu(self.gate_proj(x))) + + +class ConvGateLane(nn.Module): + def __init__(self, dim: int, kernel_size: int = 4): + super().__init__() + self.kernel_size = max(1, int(kernel_size)) + self.dwconv = nn.Conv1d(dim, dim, kernel_size=self.kernel_size, groups=dim, bias=False) + self.gate_proj = CastedLinear(dim, dim, bias=False) + self.pointwise = CastedLinear(dim, dim, bias=False) + self.pointwise._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + y = x.transpose(1, 2) + if self.kernel_size > 1: + y = F.pad(y, (self.kernel_size - 1, 0)) + y = self.dwconv(y).transpose(1, 2) + return self.pointwise(y * torch.sigmoid(self.gate_proj(x))) + + +class TinySSMLane(nn.Module): + def __init__(self, dim: int, state_dim: int = 8, expand: int = 2, conv_kernel: int = 8, gate: bool = True): + super().__init__() + self.inner_dim = max(1, int(expand)) * dim + self.state_dim = max(1, int(state_dim)) + self.conv_kernel = max(1, int(conv_kernel)) + self.use_gate = bool(gate) + + self.in_proj = CastedLinear(dim, self.inner_dim, bias=False) + self.out_proj = CastedLinear(self.inner_dim, dim, bias=False) + self.out_proj._zero_init = True + self.gate_proj = CastedLinear(dim, self.inner_dim, bias=False) if self.use_gate else None + + self.ssm_A_log = nn.Parameter(torch.zeros(self.inner_dim, self.state_dim, dtype=torch.float32)) + self.ssm_B = nn.Parameter(torch.randn(self.inner_dim, self.state_dim, dtype=torch.float32) * 0.02) + self.ssm_C = nn.Parameter(torch.randn(self.inner_dim, self.state_dim, dtype=torch.float32) * 0.02) + self.register_buffer("ssm_t", torch.arange(self.conv_kernel, dtype=torch.float32), persistent=False) + + def _causal_kernel(self, dtype: torch.dtype, device: torch.device) -> Tensor: + t = self.ssm_t.to(device=device, dtype=torch.float32) + decay = torch.exp(-F.softplus(self.ssm_A_log.float()).unsqueeze(-1) * t.view(1, 1, -1)) + weights = (self.ssm_B.float() * self.ssm_C.float()).unsqueeze(-1) * decay + kernel = weights.sum(dim=1).to(dtype=dtype).flip(-1) + return kernel.unsqueeze(1) + + def forward(self, x: Tensor) -> Tensor: + u = self.in_proj(x) + y = u.transpose(1, 2) + if self.conv_kernel > 1: + y = F.pad(y, (self.conv_kernel - 1, 0)) + y = F.conv1d(y, self._causal_kernel(dtype=u.dtype, device=u.device), groups=self.inner_dim).transpose(1, 2) + if self.gate_proj is not None: + y = y * torch.sigmoid(self.gate_proj(x)) + return self.out_proj(F.silu(y)) + + +def count_trainable_params(module: nn.Module | None) -> int: + if module is None: + return 0 + return sum(p.numel() for p in module.parameters() if p.requires_grad) + + +class Block(nn.Module): + def __init__(self, args: Hyperparameters, layer_idx=0, xsa_enabled=False, xsa_eps=1e-6): + super().__init__() + dim = args.model_dim + self.layer_idx = layer_idx + self.ple_mode = args.ple_mode + + self.attn_norm 
= RMSNorm(dim, affine=args.affine_norm) + self.mlp_norm = RMSNorm(dim, affine=args.affine_norm) + self.attn = CausalSelfAttention(args, layer_idx=layer_idx, xsa_enabled=xsa_enabled, xsa_eps=xsa_eps) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + self.parallel_residual_enabled = ( + args.parallel_residual_enabled and + (layer_idx >= args.num_layers - args.parallel_residual_last_n_layers) + ) + self.parallel_residual_mode = args.parallel_residual_mode + + self.parallel_v2_enabled = ( + args.parallel_v2_enabled and + (layer_idx >= args.num_layers - args.parallel_v2_last_n_layers) + ) + self.parallel_v2_mode = args.parallel_v2_mode + self.parallel_v2_second_lane_name = args.parallel_v2_second_lane + self.parallel_v2_norm_shared = args.parallel_v2_norm_shared + self.parallel_v2_use_post_attn_ple = args.parallel_v2_use_post_attn_ple + self.parallel_v2_delayed_merge = self.parallel_v2_enabled and self.parallel_v2_mode == "delayed_merge" + self.parallel_v2_capture_norm_ratios = False + + if self.parallel_residual_enabled and self.parallel_v2_enabled: + raise ValueError("PARALLEL_RESIDUAL_ENABLED and PARALLEL_V2_ENABLED cannot both apply to the same layer") + + needs_mlp = (not self.parallel_v2_enabled) or self.parallel_residual_enabled or self.parallel_v2_second_lane_name == "mlp" + self.mlp = MLP(dim, args.mlp_mult) if needs_mlp else None + + self.second_lane: nn.Module | None = None + if self.parallel_v2_enabled and self.parallel_v2_second_lane_name != "mlp": + if self.parallel_v2_second_lane_name == "gated_linear": + self.second_lane = GatedLinearLane(dim, mult=args.gated_linear_mult, zero_init=args.gated_linear_zero_init) + elif self.parallel_v2_second_lane_name == "conv_gate": + self.second_lane = ConvGateLane(dim, kernel_size=args.conv_gate_kernel_size) + elif self.parallel_v2_second_lane_name == "ssm": + self.second_lane = TinySSMLane( + dim, + state_dim=args.ssm_state_dim, + expand=args.ssm_expand, + conv_kernel=args.ssm_conv_kernel, + gate=args.ssm_gate, + ) + else: + raise ValueError(f"Unknown PARALLEL_V2_SECOND_LANE: {self.parallel_v2_second_lane_name}") + + self.learnable_ln_scale = args.learnable_ln_scale + init_scale = 1.0 / (math.sqrt(layer_idx + 1) + 0.1 * layer_idx) if args.ln_scale else 1.0 + if self.learnable_ln_scale: + self.layer_scale = nn.Parameter(torch.tensor([init_scale], dtype=torch.float32)) + else: + self.layer_scale = init_scale + + if self.parallel_residual_enabled: + self.parallel_attn_scale = nn.Parameter( + torch.full((dim,), args.parallel_residual_init_attn, dtype=torch.float32) + ) + self.parallel_mlp_scale = nn.Parameter( + torch.full((dim,), args.parallel_residual_init_mlp, dtype=torch.float32) + ) + if self.parallel_residual_mode == "gated_add": + self.parallel_gate = nn.Parameter( + torch.full((dim,), args.parallel_residual_gate_init, dtype=torch.float32) + ) + + if self.parallel_v2_enabled: + self.attn_scale.requires_grad_(False) + self.mlp_scale.requires_grad_(False) + if self.parallel_v2_norm_shared and self.mlp_norm.weight is not None: + self.mlp_norm.weight.requires_grad_(False) + self.parallel_v2_attn_scale = nn.Parameter( + torch.full((dim,), args.parallel_v2_init_attn, dtype=torch.float32) + ) + self.parallel_v2_second_scale = nn.Parameter( + torch.full((dim,), args.parallel_v2_init_second, dtype=torch.float32) + ) + if self.parallel_v2_mode == "gated_add": + 
self.parallel_v2_gate = nn.Parameter( + torch.full((dim,), args.parallel_v2_gate_init, dtype=torch.float32) + ) + self.register_buffer("parallel_v2_attn_norm_ratio", torch.tensor(float("nan"), dtype=torch.float32), persistent=False) + self.register_buffer("parallel_v2_second_norm_ratio", torch.tensor(float("nan"), dtype=torch.float32), persistent=False) + + def _apply_ple(self, x: Tensor, ple_all: Tensor | None, ple_apply, mode: str) -> Tensor: + if ple_apply is not None and self.ple_mode == mode: + return ple_apply(x, ple_all, self.layer_idx) + return x + + def _parallel_v2_second_lane(self, x: Tensor) -> Tensor: + if self.parallel_v2_second_lane_name == "mlp": + if self.mlp is None: + raise RuntimeError("MLP lane was not constructed") + return self.mlp(x) + if self.second_lane is None: + raise RuntimeError(f"Second lane was not constructed: {self.parallel_v2_second_lane_name}") + return self.second_lane(x) + + def _parallel_v2_merge(self, x_base: Tensor, attn_out: Tensor, second_out: Tensor) -> Tensor: + attn_delta, second_delta = self._parallel_v2_scaled_lanes(attn_out, second_out, x_base.dtype) + if self.parallel_v2_mode == "dual_add" or self.parallel_v2_mode == "delayed_merge": + return x_base + attn_delta + second_delta + if self.parallel_v2_mode == "gated_add": + gate = torch.sigmoid(self.parallel_v2_gate.to(dtype=x_base.dtype))[None, None, :] + return x_base + gate * attn_delta + (1.0 - gate) * second_delta + raise ValueError(f"Unknown PARALLEL_V2_MODE: {self.parallel_v2_mode}") + + def _parallel_v2_scaled_lanes(self, attn_out: Tensor, second_out: Tensor, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + attn_scale = self.parallel_v2_attn_scale.to(dtype=dtype)[None, None, :] + second_scale = self.parallel_v2_second_scale.to(dtype=dtype)[None, None, :] + return attn_scale * attn_out, second_scale * second_out + + def _record_parallel_v2_norm_ratios(self, x_pre: Tensor, attn_out: Tensor, second_out: Tensor) -> None: + if self.parallel_v2_capture_norm_ratios: + denom = x_pre.detach().float().norm().clamp_min(1e-8) + self.parallel_v2_attn_norm_ratio.copy_(attn_out.detach().float().norm() / denom) + self.parallel_v2_second_norm_ratio.copy_(second_out.detach().float().norm() / denom) + + + def forward( + self, + x: Tensor, + x0: Tensor, + ple_all: Tensor | None = None, + ple_apply=None, + shared_k: Tensor | None = None, + shared_v: Tensor | None = None, + prev_v: Tensor | None = None, + v0: Tensor | None = None, + ): + # -------------------------------------------------- + # 1) residual pre-mix + # -------------------------------------------------- + mix = self.resid_mix.to(dtype=x.dtype) + x_base = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # -------------------------------------------------- + # 2) optional PLE before attention / branch split + # -------------------------------------------------- + x_pre = self._apply_ple(x_base, ple_all, ple_apply, "pre_attn") + + # -------------------------------------------------- + # 3) shared scale + # -------------------------------------------------- + scale = self.layer_scale.to(dtype=x.dtype) if isinstance(self.layer_scale, nn.Parameter) else self.layer_scale + + # -------------------------------------------------- + # 4) attention branch + # -------------------------------------------------- + attn_in = self.attn_norm(x_pre) * scale + attn_out, k_eff, v_eff, raw_v = self.attn( + attn_in, + shared_k=shared_k, + shared_v=shared_v, + prev_v=prev_v, + v0=v0, + ) + + # -------------------------------------------------- + # 5) serial 
vs parallel residual update + # -------------------------------------------------- + if self.parallel_v2_enabled: + if ple_apply is not None and self.ple_mode == "post_attn" and not self.parallel_v2_use_post_attn_ple: + raise ValueError("PLE mode 'post_attn' is not supported when PARALLEL_V2_ENABLED=1") + + if self.parallel_v2_norm_shared: + second_in = attn_in + else: + second_in = self.mlp_norm(x_pre) * scale + second_out = self._parallel_v2_second_lane(second_in) + self._record_parallel_v2_norm_ratios(x_pre, attn_out, second_out) + + if self.parallel_v2_delayed_merge: + return x_pre, k_eff, v_eff, raw_v, attn_out, second_out + + x_out = self._parallel_v2_merge(x_pre, attn_out, second_out) + if self.parallel_v2_use_post_attn_ple: + x_out = self._apply_ple(x_out, ple_all, ple_apply, "post_attn") + + elif self.parallel_residual_enabled: + # In parallel mode, attention and MLP both see the same pre-branch state. + # This is the key difference from the serial Transformer block. + if self.mlp is None: + raise RuntimeError("MLP lane was not constructed") + mlp_in = self.mlp_norm(x_pre) * scale + mlp_out = self.mlp(mlp_in) + + attn_scale = self.parallel_attn_scale.to(dtype=x.dtype)[None, None, :] + mlp_scale = self.parallel_mlp_scale.to(dtype=x.dtype)[None, None, :] + + if self.parallel_residual_mode == "dual_add": + x_out = x_pre + attn_scale * attn_out + mlp_scale * mlp_out + elif self.parallel_residual_mode == "gated_add": + gate = torch.sigmoid(self.parallel_gate.to(dtype=x.dtype))[None, None, :] + x_out = x_pre + gate * (attn_scale * attn_out) + (1.0 - gate) * (mlp_scale * mlp_out) + else: + raise ValueError(f"Unknown PARALLEL_RESIDUAL_MODE: {self.parallel_residual_mode}") + + # NOTE: + # In parallel mode, "post_attn" becomes ambiguous because there is no single + # canonical state that is "after attention but before MLP" anymore. + # To keep semantics clean, we do not support post_attn PLE in parallel mode. + if ple_apply is not None and self.ple_mode == "post_attn": + raise ValueError("PLE mode 'post_attn' is not supported when PARALLEL_RESIDUAL_ENABLED=1") + + else: + # Serial Transformer block: + # MLP consumes the post-attention hidden state. 
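+            # Schematically, the serial update below computes (per-channel scales):
+            #   x_after_attn = x_pre + attn_scale * attn_out
+            #   x_out = x_after_attn + mlp_scale * MLP(mlp_norm(x_after_attn) * scale)
+            # with the optional "post_attn" PLE injection between the two residual adds.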
+ if self.mlp is None: + raise RuntimeError("MLP lane was not constructed") + x_after_attn = x_pre + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + + x_after_attn = self._apply_ple(x_after_attn, ple_all, ple_apply, "post_attn") + + mlp_in = self.mlp_norm(x_after_attn) * scale + mlp_out = self.mlp(mlp_in) + + x_out = x_after_attn + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * mlp_out + + # -------------------------------------------------- + # 6) optional PLE after FFN / final merge + # -------------------------------------------------- + x_out = self._apply_ple(x_out, ple_all, ple_apply, "post_ffn") + + return x_out, k_eff, v_eff, raw_v + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class CausalLocalMixing(nn.Module): + def __init__(self, dim: int, window_size: int = 4): + super().__init__() + self.window_size = window_size + self.dim = dim + w = torch.zeros(window_size, dim, dtype=torch.float32) + w[0, :] = 3.0 + self.mix_logits = nn.Parameter(w) + + def forward(self, x: Tensor) -> Tensor: + if self.window_size <= 1: + return x + w_soft = F.softmax(self.mix_logits.to(x.dtype), dim=0) + kernel = w_soft.flip(0).T.unsqueeze(1) + x_t = x.transpose(1, 2) + x_padded = F.pad(x_t, (self.window_size - 1, 0)) + out = F.conv1d(x_padded, kernel, groups=self.dim) + return out.transpose(1, 2) + + +class NGramHashEmbedding(nn.Module): + def __init__(self, vocab_size: int, dim: int, model_dim: int, max_n: int = 4): + super().__init__() + self.max_n = max_n + self.vocab_size = vocab_size + self.embeds = nn.ModuleList([nn.Embedding(vocab_size, dim) for _ in range(2, max_n + 1)]) + for emb in self.embeds: + nn.init.normal_(emb.weight, std=0.01) + self.proj = nn.Linear(dim, model_dim, bias=False) if dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.ngram_scales = nn.Parameter(torch.full((max_n - 1,), 0.05, dtype=torch.float32)) + + def ngram_hash(self, tokens: Tensor, n: int) -> Tensor: + t = tokens.to(torch.int64) + mod = self.vocab_size - 1 + out = torch.empty_like(t) + out[..., :n - 1] = mod + primes = [36313, 27191, 19393, 13127, 9767] + hash_val = t[..., n - 1:] * primes[0] + for i in range(1, n): + hash_val = torch.bitwise_xor(hash_val, t[..., n - 1 - i: -i] * primes[i]) + out[..., n - 1:] = hash_val % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + fused_h = None + for idx, n in enumerate(range(2, self.max_n + 1)): + h_n = self.embeds[idx](self.ngram_hash(token_ids, n)) + scaled_h = h_n * self.ngram_scales[idx].to(dtype=h_n.dtype) + fused_h = scaled_h if fused_h is None else fused_h + scaled_h + if self.proj is not None: + fused_h = self.proj(fused_h) + return fused_h + + +def compute_ngram_fade_scale(step, total_steps, enabled, start_frac, end_frac, min_scale=0.0) -> float: + if not enabled: + return 1.0 + if total_steps <= 0: + return 1.0 + p = step / float(total_steps) + start_frac = max(0.0, min(1.0, start_frac)) + end_frac = max(start_frac + 1e-8, min(1.0, end_frac)) + min_scale = max(0.0, min(1.0, min_scale)) + if p <= start_frac: + return 1.0 + if p >= end_frac: + return min_scale + alpha = (p - start_frac) / (end_frac - start_frac) + return (1.0 - alpha) + alpha * min_scale + + +class 
GPT(nn.Module): + def __init__(self, args: Hyperparameters, master_process: bool = True): + super().__init__() + self.args = args + self.fda_mode = args.fda_mode + self.skip_distance = 2 + self.num_layers = args.num_layers + self.cross_layer_kv_sharing_enabled = args.cross_layer_kv_sharing_enabled + self.cross_layer_kv_last_n_layers = args.cross_layer_kv_last_n_layers + self.cross_layer_kv_pairwise = args.cross_layer_kv_pairwise + self.tie_embeddings = args.tie_embeddings + self.tied_embed_init_std = args.tied_embed_init_std + self.logit_softcap = args.logit_softcap + self.scaledlm_head = args.scaledlm_head + self.mtphead_mlpmode = args.mtphead_mlpmode + self.depth_recur_enabled = args.depth_recur_enabled + self.num_stem_blocks = args.num_stem_blocks + self.num_core_blocks = args.num_core_blocks + self.num_core_repeats = args.num_core_repeats + self.num_tail_blocks = args.num_stem_blocks + args.parallel_v2_mode = args.parallel_v2_mode.strip().lower() + args.parallel_v2_second_lane = args.parallel_v2_second_lane.strip().lower() + self.parallel_v2_enabled = args.parallel_v2_enabled + self.parallel_v2_mode = args.parallel_v2_mode + self.parallel_v2_has_delayed_merge = ( + args.parallel_v2_enabled and + args.parallel_v2_mode == "delayed_merge" and + args.parallel_v2_last_n_layers > 0 + ) + if args.parallel_v2_enabled: + if args.parallel_residual_enabled: + raise ValueError("PARALLEL_RESIDUAL_ENABLED and PARALLEL_V2_ENABLED are mutually exclusive") + if args.parallel_v2_mode not in {"dual_add", "gated_add", "delayed_merge"}: + raise ValueError(f"Unknown PARALLEL_V2_MODE: {args.parallel_v2_mode}") + if args.parallel_v2_second_lane not in {"mlp", "gated_linear", "conv_gate", "ssm"}: + raise ValueError(f"Unknown PARALLEL_V2_SECOND_LANE: {args.parallel_v2_second_lane}") + if args.parallel_v2_last_n_layers < 0: + raise ValueError("PARALLEL_V2_LAST_N_LAYERS must be >= 0") + if args.parallel_v2_mode == "delayed_merge" and args.parallel_v2_delayed_merge_steps != 1: + raise ValueError("PARALLEL_V2_DELAYED_MERGE_STEPS currently only supports 1") + if args.ple_enabled and args.ple_mode == "post_attn" and not args.parallel_v2_use_post_attn_ple: + raise ValueError("Set PARALLEL_V2_USE_POST_ATTN_PLE=1 to opt into post-merge PLE semantics") + + model_dim = args.model_dim + num_layers = args.num_layers + self.tok_emb = nn.Embedding(args.vocab_size, model_dim) + + self.ple = PLEModule(args) + + self.smear_mode = args.smear_mode + if self.smear_mode: + self.local_mix = CausalLocalMixing(model_dim, window_size=args.smear_window) + if master_process: + print(f"Architecture: Local Causal Mixing (Window={args.smear_window})") + self.smear_gate = args.smear_gate + if self.smear_gate: + self.smear_gate_module = SmearGate(model_dim) + if master_process: + print("Architecture: SmearGate (1-step causal blend)") + + self.ngram_max_n = args.ngram_max_n + if args.ngram_vocab_size > 0 and self.ngram_max_n >= 2: + self.ngram = NGramHashEmbedding(args.ngram_vocab_size, args.ngram_dim, model_dim, max_n=self.ngram_max_n) + if master_process: + print(f"Architecture: Discrete N-Gram Hash (Max N={self.ngram_max_n})") + else: + self.ngram = None + self.register_buffer("ngram_global_scale_buf", torch.tensor(1.0, dtype=torch.float32), persistent=False) + + self.blocks = nn.ModuleList([ + Block( + args, + layer_idx=i, + xsa_enabled=(args.xsa_enabled and i >= num_layers - args.xsa_last_n_layers), + xsa_eps=args.xsa_eps, + ) for i in range(num_layers) + ]) + + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = 
num_layers - self.num_encoder_layers + self.bifpn_mode = args.bifpn_mode + self.bifpn2_mode = args.bifpn2_mode + + if self.depth_recur_enabled: + self.skip_weights = nn.Parameter(torch.ones(self.num_stem_blocks, model_dim, dtype=torch.float32)) + elif self.bifpn_mode: + w = torch.full((self.num_decoder_layers, self.num_encoder_layers), 0.1, dtype=torch.float32) + for i in range(self.num_decoder_layers): + sym_idx = self.num_encoder_layers - 1 - i + if sym_idx >= 0: + w[i, sym_idx] = 1.0 + self.bifpn_weights = nn.Parameter(w) + elif self.bifpn2_mode: + self.structured_bifpn = StructuredGroupSignedBiFPN( + num_decoder_layers=self.num_decoder_layers, + num_encoder_layers=self.num_encoder_layers, + model_dim=model_dim, + group_count=args.bifpn_group_count, + band_width=args.bifpn_band_width, + norm_eps=args.bifpn_norm_eps, + init_main=args.bifpn_init_main, + init_neighbor=args.bifpn_init_neighbor, + init_far=args.bifpn_init_far, + ) + elif self.fda_mode: + num_conn = max(0, num_layers - self.skip_distance) + self.skip_weights = nn.Parameter(torch.ones(num_conn, model_dim, dtype=torch.float32)) + else: + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + self.final_norm = RMSNorm(model_dim, affine=args.affine_norm) + self.lm_head = None if args.tie_embeddings else CastedLinear(model_dim, args.vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + + self.mtp_num_heads = args.mtp_num_heads + self.mtp_loss_weight = args.mtp_loss_weight + if self.mtp_num_heads > 0: + if self.mtphead_mlpmode: + self.mtp_heads = nn.ModuleList([ + nn.Sequential( + nn.Linear(model_dim, model_dim * 2, bias=False), + nn.GELU(), + nn.Linear(model_dim * 2, args.vocab_size, bias=False), + ) for _ in range(self.mtp_num_heads) + ]) + else: + self.mtp_heads = nn.ModuleList([CastedLinear(model_dim, args.vocab_size, bias=False) for _ in range(self.mtp_num_heads)]) + else: + self.mtp_heads = nn.ModuleList([]) + + self.max_logit_pre_cap = 0.0 + self._init_weights() + + def _init_weights(self): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def set_parallel_v2_norm_capture(self, enabled: bool) -> None: + for block in self.blocks: + if getattr(block, "parallel_v2_enabled", False): + block.parallel_v2_capture_norm_ratios = enabled + if enabled: + block.parallel_v2_attn_norm_ratio.fill_(float("nan")) + block.parallel_v2_second_norm_ratio.fill_(float("nan")) + + def _should_share_from_prev_layer(self, block_idx: int) -> bool: + if not self.cross_layer_kv_sharing_enabled: + return False + return block_idx >= self.num_layers - self.cross_layer_kv_last_n_layers + + def _merge_parallel_v2_pending( + self, + x: Tensor, + ple_all: Tensor | None, + ple_apply, + pending_layer_idx: int, + pending_attn_delta: Tensor | None, + pending_second_delta: Tensor | None, + ) -> tuple[Tensor, int, Tensor | None, Tensor | None]: + if self.parallel_v2_has_delayed_merge and pending_attn_delta is not None: + if pending_second_delta is None: + raise RuntimeError("Missing delayed Parallel v2 second-lane tensor") + x = x + pending_attn_delta + pending_second_delta + if ple_apply is not None and self.args.ple_mode == "post_attn" and self.args.parallel_v2_use_post_attn_ple: + x = ple_apply(x, 
ple_all, pending_layer_idx) + if ple_apply is not None and self.args.ple_mode == "post_ffn": + x = ple_apply(x, ple_all, pending_layer_idx) + return x, -1, None, None + return x, pending_layer_idx, pending_attn_delta, pending_second_delta + + def _forward_hidden(self, input_ids: Tensor) -> Tensor: + last_v_for_cross_layer_v: Tensor | None = None + last_k_for_kv_sharing: Tensor | None = None + last_v_for_kv_sharing: Tensor | None = None + v0_global: Tensor | None = None + pending_parallel_v2_layer_idx = -1 + pending_parallel_v2_attn_delta: Tensor | None = None + pending_parallel_v2_second_delta: Tensor | None = None + + x = self.tok_emb(input_ids) + ple_all = self.ple.build_all(input_ids, x) + ple_apply = self.ple.apply if self.ple.enabled else None + + if getattr(self, "ngram", None) is not None: + scale = self.ngram_global_scale_buf.to(dtype=x.dtype) + x = x + scale * self.ngram(input_ids) + + x = F.rms_norm(x, (x.size(-1),)) + if self.smear_mode: + x = self.local_mix(x) + if self.smear_gate: + x = self.smear_gate_module(x) + x0 = x + + if self.depth_recur_enabled: + stem_skips: list[Tensor] = [] + for i in range(self.num_stem_blocks): + shared_k = last_k_for_kv_sharing if self._should_share_from_prev_layer(i) else None + shared_v = last_v_for_kv_sharing if self._should_share_from_prev_layer(i) else None + x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + self.blocks[i], i, x, x0, ple_all, ple_apply, shared_k, shared_v, + last_v_for_cross_layer_v, v0_global, + pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + ) + if v0_global is None: + v0_global = raw_v + last_v_for_cross_layer_v = v_eff + last_k_for_kv_sharing = k_eff + last_v_for_kv_sharing = v_eff + stem_skips.append(x) + + core_start = self.num_stem_blocks + for _ in range(self.num_core_repeats): + for j in range(self.num_core_blocks): + block_idx = core_start + j + shared_k = last_k_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + shared_v = last_v_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + self.blocks[block_idx], block_idx, x, x0, ple_all, ple_apply, shared_k, shared_v, + last_v_for_cross_layer_v, v0_global, + pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + ) + if v0_global is None: + v0_global = raw_v + last_v_for_cross_layer_v = v_eff + last_k_for_kv_sharing = k_eff + last_v_for_kv_sharing = v_eff + + tail_start = self.num_stem_blocks + self.num_core_blocks + for i in range(self.num_tail_blocks): + skip_x = stem_skips[self.num_stem_blocks - 1 - i] + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skip_x + block_idx = tail_start + i + shared_k = last_k_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + shared_v = last_v_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + self.blocks[block_idx], block_idx, x, x0, ple_all, ple_apply, shared_k, shared_v, + last_v_for_cross_layer_v, v0_global, + pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + ) + if v0_global is None: + v0_global = raw_v + 
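+                # Cache this layer's outputs for downstream layers: k_eff/v_eff feed
+                # cross-layer KV sharing and cross-layer V mixing, and the first
+                # layer's raw_v becomes the global value-residual anchor (v0).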
last_v_for_cross_layer_v = v_eff + last_k_for_kv_sharing = k_eff + last_v_for_kv_sharing = v_eff + + elif self.bifpn2_mode: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + shared_k = last_k_for_kv_sharing if self._should_share_from_prev_layer(i) else None + shared_v = last_v_for_kv_sharing if self._should_share_from_prev_layer(i) else None + x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + self.blocks[i], i, x, x0, ple_all, ple_apply, shared_k, shared_v, + last_v_for_cross_layer_v, v0_global, + pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + ) + if v0_global is None: + v0_global = raw_v + skips.append(x) + last_v_for_cross_layer_v = v_eff + last_k_for_kv_sharing = k_eff + last_v_for_kv_sharing = v_eff + + for i in range(self.num_decoder_layers): + fusion_feature = self.structured_bifpn(skips=skips, decoder_idx=i, x_dtype=x.dtype) + x = x + fusion_feature + block_idx = self.num_encoder_layers + i + shared_k = last_k_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + shared_v = last_v_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + self.blocks[block_idx], block_idx, x, x0, ple_all, ple_apply, shared_k, shared_v, + last_v_for_cross_layer_v, v0_global, + pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + ) + if v0_global is None: + v0_global = raw_v + last_v_for_cross_layer_v = v_eff + last_k_for_kv_sharing = k_eff + last_v_for_kv_sharing = v_eff + + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + self.blocks[i], i, x, x0, ple_all, ple_apply, None, None, None, v0_global, + pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + ) + if v0_global is None: + v0_global = raw_v + skips.append(x) + + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + block_idx = self.num_encoder_layers + i + x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + self.blocks[block_idx], block_idx, x, x0, ple_all, ple_apply, None, None, None, v0_global, + pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + ) + if v0_global is None: + v0_global = raw_v + + x, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._merge_parallel_v2_pending( + x, + ple_all, + ple_apply, + pending_parallel_v2_layer_idx, + pending_parallel_v2_attn_delta, + pending_parallel_v2_second_delta, + ) + x = self.final_norm(x) + return x + + def _project_logits_from_hidden(self, x: Tensor) -> Tensor: + B, T, D = x.shape + x_flat = x.reshape(-1, D) + + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + if self.scaledlm_head: + logits_proj = logits_proj / math.sqrt(D) + else: + logits_proj = self.lm_head(x_flat) + if self.scaledlm_head: + logits_proj = logits_proj / math.sqrt(D) + + if not self.training or getattr(self, "_log_logits", False): + 
self.max_logit_pre_cap = logits_proj.detach().abs().max() + + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return logits.view(B, T, -1) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self._forward_hidden(input_ids) + return self._project_logits_from_hidden(x) + + def _run_block( + self, + block: Block, + block_idx: int, + x: Tensor, + x0: Tensor, + ple_all: Tensor | None, + ple_apply, + shared_k: Tensor | None, + shared_v: Tensor | None, + prev_v: Tensor | None, + v0: Tensor | None, + pending_layer_idx: int, + pending_attn_delta: Tensor | None, + pending_second_delta: Tensor | None, + ): + x, pending_layer_idx, pending_attn_delta, pending_second_delta = self._merge_parallel_v2_pending( + x, + ple_all, + ple_apply, + pending_layer_idx, + pending_attn_delta, + pending_second_delta, + ) + out = block( + x, + x0, + ple_all=ple_all, + ple_apply=ple_apply, + shared_k=shared_k, + shared_v=shared_v, + prev_v=prev_v, + v0=v0, + ) + if self.parallel_v2_has_delayed_merge and block.parallel_v2_delayed_merge: + x, k_eff, v_eff, raw_v, attn_out, second_out = out + attn_delta, second_delta = block._parallel_v2_scaled_lanes(attn_out, second_out, x.dtype) + return x, k_eff, v_eff, raw_v, block_idx, attn_delta, second_delta + x, k_eff, v_eff, raw_v = out + return x, k_eff, v_eff, raw_v, pending_layer_idx, pending_attn_delta, pending_second_delta + + def forward(self, input_ids: Tensor, target_ids: Tensor, reduction: str = "mean", ngram_global_scale: float = 1.0): + x = self._forward_hidden(input_ids) + x_original = x + B, T, D = x.shape + + logits = self._project_logits_from_hidden(x) + targets = target_ids.reshape(-1) + + if reduction == "none": + loss_flat = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + targets, + reduction="none", + ) + loss_tokens = loss_flat.view(B, T) + return loss_tokens.mean(), loss_tokens + + main_loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + targets, + reduction="mean", + ) + + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + mtp_loss_sum = x_original.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = T - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x_original[:, :valid_t, :].reshape(-1, D) + mtp_targets = target_ids[:, k + 1:].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy( + mtp_logits.float(), mtp_targets, reduction="mean" + ) + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + + return main_loss + + # def forward(self, input_ids: Tensor, target_ids: Tensor, reduction: str = "mean", ngram_global_scale: float = 1.0): + # last_v_for_cross_layer_v: Tensor | None = None + # last_k_for_kv_sharing: Tensor | None = None + # last_v_for_kv_sharing: Tensor | None = None + # v0_global: Tensor | None = None + # pending_parallel_v2_layer_idx = -1 + # pending_parallel_v2_attn_delta: Tensor | None = None + # pending_parallel_v2_second_delta: Tensor | None = None + + # x = self.tok_emb(input_ids) + # ple_all = self.ple.build_all(input_ids, x) + # ple_apply = self.ple.apply if self.ple.enabled else None + + # if getattr(self, "ngram", None) is not None: + # scale = self.ngram_global_scale_buf.to(dtype=x.dtype) + # x = x + scale * self.ngram(input_ids) + + # x = F.rms_norm(x, (x.size(-1),)) + # if 
self.smear_mode: + # x = self.local_mix(x) + # if self.smear_gate: + # x = self.smear_gate_module(x) + # x0 = x + + # if self.depth_recur_enabled: + # stem_skips: list[Tensor] = [] + # for i in range(self.num_stem_blocks): + # shared_k = last_k_for_kv_sharing if self._should_share_from_prev_layer(i) else None + # shared_v = last_v_for_kv_sharing if self._should_share_from_prev_layer(i) else None + # x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + # self.blocks[i], i, x, x0, ple_all, ple_apply, shared_k, shared_v, + # last_v_for_cross_layer_v, v0_global, + # pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + # ) + # if v0_global is None: + # v0_global = raw_v + # last_v_for_cross_layer_v = v_eff + # last_k_for_kv_sharing = k_eff + # last_v_for_kv_sharing = v_eff + # stem_skips.append(x) + + # core_start = self.num_stem_blocks + # for _ in range(self.num_core_repeats): + # for j in range(self.num_core_blocks): + # block_idx = core_start + j + # shared_k = last_k_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + # shared_v = last_v_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + # x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + # self.blocks[block_idx], block_idx, x, x0, ple_all, ple_apply, shared_k, shared_v, + # last_v_for_cross_layer_v, v0_global, + # pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + # ) + # if v0_global is None: + # v0_global = raw_v + # last_v_for_cross_layer_v = v_eff + # last_k_for_kv_sharing = k_eff + # last_v_for_kv_sharing = v_eff + + # tail_start = self.num_stem_blocks + self.num_core_blocks + # for i in range(self.num_tail_blocks): + # skip_x = stem_skips[self.num_stem_blocks - 1 - i] + # x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skip_x + # block_idx = tail_start + i + # shared_k = last_k_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + # shared_v = last_v_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + # x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + # self.blocks[block_idx], block_idx, x, x0, ple_all, ple_apply, shared_k, shared_v, + # last_v_for_cross_layer_v, v0_global, + # pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + # ) + # if v0_global is None: + # v0_global = raw_v + # last_v_for_cross_layer_v = v_eff + # last_k_for_kv_sharing = k_eff + # last_v_for_kv_sharing = v_eff + + # elif self.bifpn2_mode: + # skips: list[Tensor] = [] + # for i in range(self.num_encoder_layers): + # shared_k = last_k_for_kv_sharing if self._should_share_from_prev_layer(i) else None + # shared_v = last_v_for_kv_sharing if self._should_share_from_prev_layer(i) else None + # x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + # self.blocks[i], i, x, x0, ple_all, ple_apply, shared_k, shared_v, + # last_v_for_cross_layer_v, v0_global, + # pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + # ) + # if v0_global is None: + # v0_global = raw_v + # skips.append(x) + # last_v_for_cross_layer_v = v_eff + # 
last_k_for_kv_sharing = k_eff + # last_v_for_kv_sharing = v_eff + + # for i in range(self.num_decoder_layers): + # fusion_feature = self.structured_bifpn(skips=skips, decoder_idx=i, x_dtype=x.dtype) + # x = x + fusion_feature + # block_idx = self.num_encoder_layers + i + # shared_k = last_k_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + # shared_v = last_v_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + # x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + # self.blocks[block_idx], block_idx, x, x0, ple_all, ple_apply, shared_k, shared_v, + # last_v_for_cross_layer_v, v0_global, + # pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + # ) + # if v0_global is None: + # v0_global = raw_v + # last_v_for_cross_layer_v = v_eff + # last_k_for_kv_sharing = k_eff + # last_v_for_kv_sharing = v_eff + + # else: + # skips: list[Tensor] = [] + # for i in range(self.num_encoder_layers): + # x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + # self.blocks[i], i, x, x0, ple_all, ple_apply, None, None, None, v0_global, + # pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + # ) + # if v0_global is None: + # v0_global = raw_v + # skips.append(x) + # for i in range(self.num_decoder_layers): + # if skips: + # x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + # block_idx = self.num_encoder_layers + i + # x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + # self.blocks[block_idx], block_idx, x, x0, ple_all, ple_apply, None, None, None, v0_global, + # pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + # ) + # if v0_global is None: + # v0_global = raw_v + + # x, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._merge_parallel_v2_pending( + # x, + # ple_all, + # ple_apply, + # pending_parallel_v2_layer_idx, + # pending_parallel_v2_attn_delta, + # pending_parallel_v2_second_delta, + # ) + # x = self.final_norm(x) + # x_original = x + # x = x.reshape(-1, x.size(-1)) + # targets = target_ids.reshape(-1) + + # if self.tie_embeddings: + # logits_proj = F.linear(x, self.tok_emb.weight) + # if self.scaledlm_head: + # logits_proj = logits_proj / math.sqrt(x.size(-1)) + # else: + # logits_proj = self.lm_head(x) + # if self.scaledlm_head: + # logits_proj = logits_proj / math.sqrt(x.size(-1)) + + # if not self.training or getattr(self, "_log_logits", False): + # self.max_logit_pre_cap = logits_proj.detach().abs().max() + + # logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + # if reduction == "none": + # loss_flat = F.cross_entropy(logits.float(), targets, reduction="none") + # loss_tokens = loss_flat.view(input_ids.shape[0], input_ids.shape[1]) + # return loss_tokens.mean(), loss_tokens + + # main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + + # if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + # _, seqlen, dim = x_original.shape + # mtp_loss_sum = x_original.new_zeros(()) + # mtp_loss_count = 0 + # for k, mtp_head in enumerate(self.mtp_heads): + # valid_t = seqlen - (k + 1) + # if valid_t <= 0: + # continue 
+ # mtp_hidden = x_original[:, :valid_t, :].reshape(-1, dim) + # mtp_targets = target_ids[:, k + 1:].reshape(-1) + # mtp_logits_proj = mtp_head(mtp_hidden) + # mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + # mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + # mtp_loss_count += 1 + # if mtp_loss_count > 0: + # main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + + # return main_loss + + +# ============================================================ +# DATA STREAMING +# ============================================================ + +class TokenStream: + def __init__(self, pattern: str): + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + if not self.files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[0]) + self.pos = 0 + + def _advance_file(self): + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos: self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int): + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start: start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ============================================================ +# EVAL +# ============================================================ + +@torch.no_grad() +def eval_val(args: Hyperparameters, model: nn.Module, rank: int, world_size: int, device: torch.device, grad_accum_steps: int, val_tokens: Tensor, base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError("VAL_BATCH_SIZE too small") + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = 
val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +@torch.no_grad() +def eval_val_sliding(args: Hyperparameters, model: nn.Module, rank: int, world_size: int, device: torch.device, grad_accum_steps: int, val_tokens: Tensor, base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor): + model.eval() + seq_len = args.train_seq_len + stride = args.eval_stride + batch_seqs = args.eval_batch_seqs + + total_loss_sum = torch.zeros(1, device=device, dtype=torch.float64) + total_token_count = torch.zeros(1, device=device, dtype=torch.float64) + total_byte_count = torch.zeros(1, device=device, dtype=torch.float64) + + max_start = val_tokens.numel() - 1 - seq_len + starts = list(range(0, max_start + 1, stride)) + if starts[-1] != max_start: + starts.append(max_start) + starts = starts[rank::world_size] + + def _score_batch(xb: Tensor, yb: Tensor): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + _, loss_tokens = model(xb, yb, reduction="none") + score_loss = loss_tokens[:, -stride:] + token_count = torch.tensor(score_loss.numel(), device=device, dtype=torch.float64) + scored_y = yb[:, -stride:] + scored_x = xb[:, -stride:] + byte_count = tokens_to_bytes_count(scored_x, scored_y, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut).to(torch.float64) + loss_sum = score_loss.sum(dtype=torch.float64) + return loss_sum, token_count, byte_count + + batch_x = [] + batch_y = [] + for start in starts: + chunk = val_tokens[start: start + seq_len + 1].to(device=device, dtype=torch.int64, non_blocking=True) + batch_x.append(chunk[:-1]) + batch_y.append(chunk[1:]) + if len(batch_x) == batch_seqs: + xb = torch.stack(batch_x) + yb = torch.stack(batch_y) + ls, tc, bc = _score_batch(xb, yb) + total_loss_sum += ls + total_token_count += tc + total_byte_count += bc + batch_x.clear() + batch_y.clear() + if batch_x: + xb = torch.stack(batch_x) + yb = torch.stack(batch_y) + ls, tc, bc = _score_batch(xb, yb) + total_loss_sum += ls + total_token_count += tc + total_byte_count += bc + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(total_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(total_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(total_byte_count, op=dist.ReduceOp.SUM) + + val_loss = (total_loss_sum / total_token_count).item() + val_bpb = (total_loss_sum / (math.log(2.0) * 
total_byte_count)).item() + model.train() + return val_loss, val_bpb + + +class BytePPM: + """ + Lightweight order-k byte PPM-style model for eval-time mixture experiments. + This is intentionally simple and readable, not a hyper-optimized PPM-D clone. + + The first version used tuple contexts and Counter objects. That is convenient, + but it creates millions of small Python objects on multi-million-token evals. + This version packs short byte contexts into integer keys and keeps single-byte + continuations in one integer until a context needs to branch. + """ + + def __init__(self, order: int = 5, max_contexts: int = 0): + self.order = max(0, int(order)) + self.max_contexts = max(0, int(max_contexts)) + self._ctx_shift = 8 * max(1, self.order) + self.ctx_counts: dict[int, int | dict[int, int]] = {} + self.ctx_totals: dict[int, int] = {} + self.skipped_new_contexts = 0 + + @property + def context_count(self) -> int: + return len(self.ctx_counts) + + @staticmethod + def _pack_single(byte: int, count: int) -> int: + return (int(count) << 8) | int(byte) + + @staticmethod + def _unpack_single(entry: int) -> tuple[int, int]: + return entry & 0xFF, entry >> 8 + + def _ctx_key(self, history: bytearray, k: int) -> int: + packed = 0 + if k > 0: + start = len(history) - k + for i in range(start, len(history)): + packed = (packed << 8) | int(history[i]) + return (int(k) << self._ctx_shift) | packed + + def _entry_true_prob_and_conf( + self, + entry: int | dict[int, int], + total: int, + true_byte: int, + ) -> tuple[float, float]: + if total <= 0: + return 0.0, 0.0 + if isinstance(entry, int): + byte, count = self._unpack_single(entry) + p_true = float(count) / float(total) if byte == true_byte else 0.0 + return p_true, float(count) / float(total) + + true_count = entry.get(true_byte, 0) + max_count = max(entry.values()) if entry else 0 + return float(true_count) / float(total), float(max_count) / float(total) + + def predict_true_and_conf(self, history: bytearray, true_byte: int) -> tuple[float, float]: + max_k = min(self.order, len(history)) + for k in range(max_k, -1, -1): + key = self._ctx_key(history, k) + entry = self.ctx_counts.get(key) + if entry is None: + continue + return self._entry_true_prob_and_conf( + entry, + self.ctx_totals.get(key, 0), + int(true_byte), + ) + return 1.0 / 256.0, 1.0 / 256.0 + + def _add_count(self, key: int, next_byte: int) -> None: + entry = self.ctx_counts.get(key) + if entry is None: + if self.max_contexts > 0 and len(self.ctx_counts) >= self.max_contexts and key != 0: + self.skipped_new_contexts += 1 + return + self.ctx_counts[key] = self._pack_single(next_byte, 1) + self.ctx_totals[key] = 1 + return + + self.ctx_totals[key] += 1 + if isinstance(entry, int): + old_byte, old_count = self._unpack_single(entry) + if old_byte == next_byte: + self.ctx_counts[key] = self._pack_single(next_byte, old_count + 1) + else: + self.ctx_counts[key] = {old_byte: old_count, int(next_byte): 1} + else: + entry[int(next_byte)] = entry.get(int(next_byte), 0) + 1 + + def update(self, history: bytearray, next_byte: int) -> None: + max_k = min(self.order, len(history)) + for k in range(0, max_k + 1): + self._add_count(self._ctx_key(history, k), int(next_byte)) + + +def nn_byte_true_prob_for_token( + p_token_target: float, + token_byte_len: int, + floor: float = 1e-6, +) -> float: + """ + True-byte probability for the spread-root token->byte projection. + The full 256-way distribution is never materialized during PPM eval. 
+    """
+ """ + if token_byte_len <= 0: + return 0.0 + L = int(token_byte_len) + p_target = max(min(p_token_target, 1.0), 1e-12) ** (1.0 / L) + peak = max(float(p_target), float(floor)) + return peak / (peak + 255.0 * float(floor)) + + +def mix_true_byte_prob( + nn_p_true: float, + ppm_p_true: float, + ppm_conf: float, + conf_threshold: float, + lambda_lo: float, + lambda_hi: float, +) -> tuple[float, float]: + lam = lambda_lo if ppm_conf >= conf_threshold else lambda_hi + return (1.0 - lam) * nn_p_true + lam * ppm_p_true, lam + + +@torch.no_grad() +def eval_val_with_ppm_mixture( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + token_piece_bytes: list[bytes], + has_leading_space: list[bool], + is_boundary_token: list[bool], + log0=print, +): + """ + Experimental byte-level mixture eval: + neural token LM + lightweight byte-level PPM. + Single-rank only for now. + """ + if world_size != 1: + raise NotImplementedError("PPM mixture eval currently supports single-rank only.") + + if args.nn_byte_projection != "spread_root": + raise ValueError(f"Unsupported NN_BYTE_PROJECTION: {args.nn_byte_projection}") + + model.eval() + + seq_len = args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + subset_limit = args.ppm_subset_tokens if args.ppm_subset_tokens > 0 else total_seqs * seq_len + + ppm = BytePPM(order=args.ppm_order, max_contexts=args.ppm_max_contexts) + history = bytearray() + + total_nll = 0.0 + total_bytes = 0 + total_tokens_seen = 0 + + for seq_idx in range(total_seqs): + if total_tokens_seen >= subset_limit: + break + + raw_start = seq_idx * seq_len + raw_end = raw_start + seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + + x = local[:-1].reshape(1, seq_len) + y = local[1:].reshape(1, seq_len) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = model.forward_logits(x) + + probs = torch.softmax(logits.float(), dim=-1)[0] # [T, V] + target_probs = probs.gather(1, y[0].unsqueeze(1)).squeeze(1).detach().cpu().tolist() + x_ids = x[0].tolist() + y_ids = y[0].tolist() + + for t, (prev_id, tgt_id) in enumerate(zip(x_ids, y_ids)): + if total_tokens_seen >= subset_limit: + break + + p_tok = float(target_probs[t]) + tok_bytes = reconstruct_token_bytes( + prev_token_id=prev_id, + token_id=tgt_id, + token_piece_bytes=token_piece_bytes, + has_leading_space=has_leading_space, + is_boundary_token=is_boundary_token, + ) + + nn_p_true = nn_byte_true_prob_for_token( + p_token_target=p_tok, + token_byte_len=len(tok_bytes), + floor=args.nn_byte_uniform_floor, + ) + + for true_b in tok_bytes: + ppm_p_true, ppm_conf = ppm.predict_true_and_conf(history, true_b) + p_true, _ = mix_true_byte_prob( + nn_p_true=nn_p_true, + ppm_p_true=ppm_p_true, + ppm_conf=ppm_conf, + conf_threshold=args.ppm_conf_threshold, + lambda_lo=args.lambda_lo, + lambda_hi=args.lambda_hi, + ) + + total_nll += -math.log(max(p_true, 1e-12)) + total_bytes += 1 + + # update AFTER scoring + ppm.update(history, true_b) + history.append(true_b) + + total_tokens_seen += 1 + + if seq_idx > 0 and seq_idx % 500 == 0: + mix_bpb = total_nll / (math.log(2.0) * max(total_bytes, 1)) + log0( + f"ppm_mix_progress seq:{seq_idx}/{total_seqs} " + f"tokens:{total_tokens_seen} bytes:{total_bytes} " + f"contexts:{ppm.context_count} skipped_ctx:{ppm.skipped_new_contexts} " + f"mix_bpb:{mix_bpb:.6f}" + ) + + mix_bpb = total_nll / (math.log(2.0) * max(total_bytes, 1)) + return mix_bpb 
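+
+# Illustrative sketch (not part of this run's eval path): how a single token's
+# bytes would be scored under the NN+PPM mixture defined above. The helper name
+# and the threshold/lambda defaults below are hypothetical example values, not
+# the settings used in this experiment.
+def _sketch_score_one_token(ppm: BytePPM, history: bytearray, tok_bytes: bytes,
+                            p_tok: float, conf_threshold: float = 0.75,
+                            lambda_lo: float = 0.2, lambda_hi: float = 0.6) -> float:
+    # Spread the token-level NN probability across the token's bytes.
+    nn_p_true = nn_byte_true_prob_for_token(p_tok, len(tok_bytes))
+    nll = 0.0
+    for b in tok_bytes:
+        # PPM predicts the true byte from its longest matching byte context.
+        ppm_p_true, ppm_conf = ppm.predict_true_and_conf(history, b)
+        p_true, _ = mix_true_byte_prob(nn_p_true, ppm_p_true, ppm_conf,
+                                       conf_threshold, lambda_lo, lambda_hi)
+        nll += -math.log(max(p_true, 1e-12))
+        ppm.update(history, b)      # update only after the byte has been scored
+        history.append(b)
+    return nll / math.log(2.0)      # bits contributed by this token's bytes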
+ +# ============================================================ +# LEGAL SCORE-FIRST LoRA-TTT +# ============================================================ + +def build_ttt_optimizer(args: Hyperparameters, lora_mgr: LoRATTTManager): + params = list(lora_mgr.lora_parameters()) + if not params: + raise RuntimeError("No LoRA params found for TTT") + return torch.optim.SGD( + params, + lr=args.ttt_lr, + momentum=args.ttt_momentum, + weight_decay=args.ttt_weight_decay, + ) + + +def eval_val_sliding_lora_ttt( + args: Hyperparameters, + base_model: nn.Module, + lora_mgr: LoRATTTManager, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + log0=print, +): + """ + Legal score-first TTT. + + Phase 1: score a chunk under inference_mode. + Phase 2: update LoRA on that already-scored chunk. + Last chunk is never trained after scoring. + """ + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + window_starts = list(range(0, total_tokens - seq_len + 1, seq_len)) + + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + ci = min(ws // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + lora_mgr.set_enabled(True) + optimizer = build_ttt_optimizer(args, lora_mgr) + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), args.ttt_batch_seqs): + batch_ws = my_windows[bi: bi + args.ttt_batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + valid_lens = [] + for i_w, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + tok = val_tokens[ws: end + 1].to(dtype=torch.int64, device=device) + x_batch[i_w, :wlen] = tok[:-1] + y_batch[i_w, :wlen] = tok[1:] + valid_lens.append(wlen) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + _, nll = base_model(x_batch, y_batch, reduction="none") + for i_w, wlen in enumerate(valid_lens): + scored_nll = nll[i_w, :wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen) + tgt = y_batch[i_w, :wlen] + prev = x_batch[i_w, :wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + is_last = ci == num_chunks - 1 + if not is_last and args.ttt_epochs > 0: + lora_mgr.reset_chunk_state() + base_model.train() + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg["lr"] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_seqs = my_seq_e - 
my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_seqs) + start_tok = chunk_start + (my_seq_s + bs) * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok: end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in lora_mgr.lora_parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(list(lora_mgr.lora_parameters()), args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 20 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) + log0(f" lora_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + lora_mgr.set_enabled(False) + base_model.eval() + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + log0(f"lora_ttt:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} elapsed={time.perf_counter()-t0:.1f}s") + return val_loss, val_bpb + + +# ============================================================ +# QUANTIZATION (minimal placeholder: keep your v4 implementation) +# ============================================================ + +# ============================================================ +# INT8 + ZLIB SERIALIZATION +# ============================================================ + +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + p for p in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") if p +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +DYNAMIC_CLIP_Q_LIST = [ + float(p) / 100.0 + for p in os.environ.get( + "DYNAMIC_CLIP_PERCENTILES", + "100.0,99.9999,99.9995,99.995,99.99,99.95,99.9,99.8", + ).split(",") + if p.strip() +] + + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_scale, best_mse = None, None, float("inf") + for q_pct in DYNAMIC_CLIP_Q_LIST: + if q_pct >= 1.0: + clip_abs = t32.abs().max(dim=1).values + else: + clip_abs = ( + torch.quantile(t32.abs(), q_pct, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32, device=t32.device) + ) + clipped = 
torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8) + mse = F.mse_loss(q.float() * scale[:, None], t32).item() + if best_q is None or mse < best_mse: + best_mse, best_q, best_scale = mse, q, scale + return best_q.contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", + "baseline_tensor_bytes", "int8_payload_bytes"), 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def dequantize_state_dict_int8(obj: dict) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def compute_submission_size(state_dict: dict[str, Tensor], code: str) -> tuple[int, int, dict]: + """Return (quant_file_bytes, code_bytes, stats) without writing to disk.""" + quant_obj, stats = quantize_state_dict_int8(state_dict) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = 
zlib.compress(buf.getvalue(), level=9) + code_bytes = len(code.encode("utf-8")) + return len(blob), code_bytes, stats + + +def serialize_model(base_model: nn.Module, output_dir: str, code: str, log0) -> tuple[str, dict]: + """Save final_model.pt (raw float) + final_model.int8.ptz (int8+zlib), log sizes.""" + # Raw checkpoint — used by eval_only_main and for debugging. + raw_path = os.path.join(output_dir, "final_model.pt") + torch.save(base_model.state_dict(), raw_path) + log0(f"saved raw checkpoint: {raw_path} ({os.path.getsize(raw_path):,} bytes)") + + # Quantized + compressed artifact. + quant_obj, stats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + quant_path = os.path.join(output_dir, "final_model.int8.ptz") + with open(quant_path, "wb") as f: + f.write(blob) + quant_file_bytes = os.path.getsize(quant_path) + code_bytes = len(code.encode("utf-8")) + ratio = stats["baseline_tensor_bytes"] / max(stats["int8_payload_bytes"], 1) + limit_bytes = 16 * 1024 * 1024 # 16 MB + total_bytes = quant_file_bytes + code_bytes + log0( + f"model_size int8+zlib:{quant_file_bytes} bytes " + f"code:{code_bytes} bytes total:{total_bytes} bytes " + f"limit:16MB({limit_bytes}) " + f"{'FITS' if total_bytes <= limit_bytes else 'EXCEEDS_LIMIT'}" + ) + log0( + f" payload:{stats['int8_payload_bytes']} " + f"raw_torch:{buf.tell()} compression_ratio:{ratio:.2f}x" + ) + return quant_path, stats + + +# ============================================================ +# MAIN +# ============================================================ +def assert_all_params_on_device(module: nn.Module, device: torch.device) -> None: + bad = [] + for name, p in module.named_parameters(): + if p.device != device: + bad.append((name, str(p.device), str(p.dtype), tuple(p.shape))) + if bad: + lines = ["Parameters on wrong device:"] + for name, dev, dtype, shape in bad[:100]: + lines.append(f" {name}: device={dev} dtype={dtype} shape={shape}") + raise RuntimeError("\n".join(lines)) + +def is_lora_param(name: str) -> bool: + name = name.lower() + return ("lora_" in name) or (".lora." 
in name) or ("lora_a" in name) or ("lora_b" in name) + + +def collect_parallel_v2_telemetry(base_model: GPT) -> dict: + args = base_model.args + layers = [] + total_second_lane_params = 0 + for i, block in enumerate(base_model.blocks): + if not getattr(block, "parallel_v2_enabled", False): + continue + second_lane = block.mlp if block.parallel_v2_second_lane_name == "mlp" else block.second_lane + second_lane_params = count_trainable_params(second_lane) + total_second_lane_params += second_lane_params + layer = { + "layer": i, + "mode": block.parallel_v2_mode, + "second_lane": block.parallel_v2_second_lane_name, + "second_lane_params": second_lane_params, + "attn_scale_mean": float(block.parallel_v2_attn_scale.detach().float().mean().item()), + "second_scale_mean": float(block.parallel_v2_second_scale.detach().float().mean().item()), + } + if hasattr(block, "parallel_v2_gate"): + layer["gate_mean"] = float(torch.sigmoid(block.parallel_v2_gate.detach().float()).mean().item()) + if hasattr(block, "parallel_v2_attn_norm_ratio"): + attn_ratio = block.parallel_v2_attn_norm_ratio.detach().float() + second_ratio = block.parallel_v2_second_norm_ratio.detach().float() + if bool(torch.isfinite(attn_ratio).item()): + layer["attn_norm_ratio"] = float(attn_ratio.item()) + if bool(torch.isfinite(second_ratio).item()): + layer["second_norm_ratio"] = float(second_ratio.item()) + layers.append(layer) + return { + "parallel_v2_enabled": int(args.parallel_v2_enabled), + "parallel_v2_mode": args.parallel_v2_mode, + "parallel_v2_second_lane": args.parallel_v2_second_lane, + "parallel_v2_last_n_layers": args.parallel_v2_last_n_layers, + "parallel_v2_active_layers": [layer["layer"] for layer in layers], + "parallel_v2_second_lane_params_total": total_second_lane_params, + "parallel_v2_layers": layers, + } + + +def write_jsonl(path: str, payload: dict) -> None: + parent = os.path.dirname(path) + if parent: + os.makedirs(parent, exist_ok=True) + with open(path, "a", encoding="utf-8") as f: + f.write(json.dumps(payload, sort_keys=True) + "\n") + + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 8 // max(world_size, 1))) + grad_scale = 1.0 / grad_accum_steps + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + outdir = args.output_dir or "logs" + os.makedirs(outdir, exist_ok=True) + logfile = os.path.join(outdir, f"{time.strftime('%Y%m%d_%H%M%S')}.txt") + print(logfile) + + def log0(msg: str, console: bool = True): + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with 
open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"SentencePiece .model expected: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"VOCAB_SIZE={args.vocab_size} != tokenizer vocab_size={int(sp.vocab_size())}") + + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + + token_piece_bytes, sp_has_leading_space, sp_is_boundary_token = build_sentencepiece_byte_tables( + sp, args.vocab_size + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # base_model = GPT(args, master_process=master_process).to(device).bfloat16() + # for module in base_model.modules(): + # if isinstance(module, CastedLinear): + # module.float() + # restore_low_dim_params_to_fp32(base_model) + + # # Inject LoRA wrappers BEFORE torch.compile. + # lora_mgr = None + # if args.ttt_enabled and args.ttt_mode == "lora" and args.lora_ttt_enabled: + # lora_mgr = LoRATTTManager(base_model, args) + # lora_mgr.inject() + # if master_process: + # print(f"TTT: LoRA adapters injected for targets={args.lora_ttt_targets}") + + base_model = GPT(args, master_process=master_process) + + # Inject LoRA wrappers BEFORE moving model to device / dtype and BEFORE torch.compile. 
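    # (Likely rationale for this ordering: the adapter tensors then receive the same
    # .to(device)/.bfloat16() casts as the base weights, the compiled graph already
    # traces the LoRA forward path, and the is_lora_param parameter scans below see
    # the final module tree.)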
+ lora_mgr = None + if args.ttt_enabled and args.ttt_mode == "lora" and args.lora_ttt_enabled: + lora_mgr = LoRATTTManager(base_model, args) + lora_mgr.inject() + if master_process: + print(f"TTT: LoRA adapters injected for targets={args.lora_ttt_targets}") + + base_model = base_model.to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + assert_all_params_on_device(base_model, device) + + + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 + and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + and not is_lora_param(name) + and p.requires_grad + ] + + scalar_params = [ + p + for name, p in block_named_params + if (p.ndim != 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) + and not is_lora_param(name) + and p.requires_grad + ] + + lora_params = [ + p + for name, p in base_model.named_parameters() + if is_lora_param(name) + and p.requires_grad + ] + + if master_process: + print(f"lora_params:{sum(p.numel() for p in lora_params)}") + + if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if hasattr(base_model, "bifpn_weights") and base_model.bifpn_weights.numel() > 0: + scalar_params.append(base_model.bifpn_weights) + if hasattr(base_model, "structured_bifpn"): + scalar_params.append(base_model.structured_bifpn.weights) + if hasattr(base_model, "mtp_heads") and base_model.mtp_heads is not None: + for p in base_model.mtp_heads.parameters(): + scalar_params.append(p) + if hasattr(base_model, "ngram") and base_model.ngram is not None: + if base_model.ngram.proj is not None: + scalar_params.append(base_model.ngram.proj.weight) + scalar_params.append(base_model.ngram.ngram_scales) + if hasattr(base_model, "ple") and getattr(base_model.ple, "enabled", False): + scalar_params.extend(list(base_model.ple.parameters())) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_param_groups = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if hasattr(base_model, "ngram") and base_model.ngram is not None: + for emb in base_model.ngram.embeds: + tok_param_groups.append({"params": [emb.weight], "lr": token_lr, "base_lr": token_lr}) + + optimizer_tok = torch.optim.Adam(tok_param_groups, betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in 
base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0(f"tie_embeddings:{args.tie_embeddings} matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}") + log0(f"ttt_enabled:{args.ttt_enabled} ttt_mode:{args.ttt_mode} lora_ttt_enabled:{args.lora_ttt_enabled}") + parallel_v2_snapshot = collect_parallel_v2_telemetry(base_model) + log0( + f"parallel_v2_enabled:{parallel_v2_snapshot['parallel_v2_enabled']} " + f"mode:{parallel_v2_snapshot['parallel_v2_mode']} " + f"second_lane:{parallel_v2_snapshot['parallel_v2_second_lane']} " + f"active_layers:{parallel_v2_snapshot['parallel_v2_active_layers']} " + f"second_lane_params:{parallel_v2_snapshot['parallel_v2_second_lane_params_total']}" + ) + if args.parallel_v2_enabled: + for layer in parallel_v2_snapshot["parallel_v2_layers"]: + gate_msg = f" gate_mean:{layer['gate_mean']:.6f}" if "gate_mean" in layer else "" + log0( + f"parallel_v2_layer:{layer['layer']} " + f"attn_scale_mean:{layer['attn_scale_mean']:.6f} " + f"second_scale_mean:{layer['second_scale_mean']:.6f}" + f"{gate_msg} params:{layer['second_lane_params']}" + ) + if master_process and args.telemetry_every > 0: + write_jsonl(args.telemetry_file, {"event": "init", **parallel_v2_snapshot}) + if args.ttt_enabled and args.lora_ttt_enabled: + log0( + f"lora_ttt_rank:{args.lora_ttt_rank} alpha:{args.lora_ttt_alpha} " + f"warmA:{int(args.lora_ttt_warm_start_a)} resetB:{int(args.lora_ttt_reset_b_each_chunk)} " + f"chunk:{args.ttt_chunk_tokens} wd:{args.ttt_weight_decay}" + ) + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all(): + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + use_walltime_stop = args.stop_mode == "walltime" + use_steps_stop = args.stop_mode == "steps" + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if (use_walltime_stop and args.max_wallclock_seconds > 0) else None + hard_step_limit = args.max_train_steps if (use_steps_stop and args.max_train_steps > 0) else args.iterations + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if use_steps_stop: + warmdown_start = max(hard_step_limit - args.warmdown_iters, 0) + return max((hard_step_limit - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < hard_step_limit else 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + 
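                    # Note: these warmup iterations exercise the full forward/backward/optimizer
                    # path before timing starts (graph compilation and allocator warmup happen
                    # here); the model/optimizer states captured just above are restored right
                    # after the loop, so warmup does not change the training trajectory.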
warmup_loss = model(x, y) + (warmup_loss * (1.0 / grad_accum_steps)).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + ema_state = None + if args.ema_enabled: + log0(f"EMA Enabled: decay={args.ema_decay}") + ema_state = {name: p.detach().float().clone() for name, p in base_model.state_dict().items()} + ema_tensors_list = list(ema_state.values()) + model_tensors_list = list(base_model.state_dict().values()) + + qat_start_step = int(hard_step_limit * (1.0 - args.late_qat_ratio)) + if args.late_qat_ratio > 0: + log0(f"Scheduled Late QAT to start at step {qat_start_step} (last {args.late_qat_ratio*100:.1f}%)") + + step = 0 + muon_momentum = args.muon_momentum_warmup_start if args.muon_momentum_warmup_steps > 0 else args.muon_momentum + + while True: + last_step = step == hard_step_limit or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + eval_fn = eval_val_sliding if args.eval_use_sliding_window else eval_val + eval_model = base_model if args.eval_use_sliding_window else model + val_loss, val_bpb = eval_fn(args, eval_model, rank, world_size, device, grad_accum_steps, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{hard_step_limit} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + break + + if step == qat_start_step and args.late_qat_ratio > 0.0: + log0(f"[Step {step}] Activating Late QAT — enabling branchless STE quantization.") + for mod in base_model.modules(): + if isinstance(mod, CastedLinear): + mod.qat_alpha.fill_(1.0) + + step_t0 = time.perf_counter() + elapsed_ms = training_time_ms + 1000.0 * (step_t0 - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + last_telemetry_x: Tensor | None = None + last_telemetry_y: Tensor | None = None + + ngram_global_scale = compute_ngram_fade_scale( + step=step, + total_steps=hard_step_limit, + enabled=args.ngram_fade_enable, + start_frac=args.ngram_fade_start_frac, + end_frac=args.ngram_fade_end_frac, + min_scale=args.ngram_fade_min_scale, + ) + + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + base_model.ngram_global_scale_buf.fill_(float(ngram_global_scale)) + loss = model(x, y) + last_telemetry_x = x.detach() + last_telemetry_y = y.detach() + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - 
frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + if args.ema_enabled and ema_state is not None: + with torch.no_grad(): + update_ema_fused(ema_tensors_list, model_tensors_list, args.ema_decay) + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step}/{hard_step_limit} train_loss:{train_loss.item():.4f} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms") + + if master_process and args.telemetry_every > 0 and step % args.telemetry_every == 0: + if args.parallel_v2_enabled and args.parallel_v2_log_norm_ratios and last_telemetry_x is not None and last_telemetry_y is not None: + was_training = base_model.training + base_model.set_parallel_v2_norm_capture(True) + try: + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + base_model(last_telemetry_x, last_telemetry_y) + finally: + base_model.set_parallel_v2_norm_capture(False) + base_model.train(was_training) + write_jsonl( + args.telemetry_file, + { + "event": "train", + "step": step, + "train_loss": float(train_loss.item()), + "lr_scale": float(scale), + **collect_parallel_v2_telemetry(base_model), + }, + ) + + if use_steps_stop: + reached_cap = step >= hard_step_limit + else: + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB") + + if args.ema_enabled and ema_state is not None: + log0("Applying EMA weights for final evaluation...") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + + # ----------------------------- + # SERIALIZATION + SIZE REPORT + ROUNDTRIP VALIDATION + # ----------------------------- + if master_process: + serialize_model(base_model, outdir, code, log0) + if distributed: + dist.barrier() + + # Roundtrip: reload quantized weights, run eval to measure degradation. 
+ quant_path_rt = os.path.join(outdir, "final_model.int8.ptz") if master_process else os.path.join(args.output_dir or "logs", "final_model.int8.ptz") + with open(quant_path_rt, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if args.ppm_enabled and rank == 0: + log0("Starting PPM byte mixture evaluation...") + ppm_mix_bpb = eval_val_with_ppm_mixture( + args=args, + model=base_model, + rank=rank, + world_size=world_size, + device=device, + val_tokens=val_tokens, + token_piece_bytes=token_piece_bytes, + has_leading_space=sp_has_leading_space, + is_boundary_token=sp_is_boundary_token, + log0=log0, + ) + log0(f"ppm_mix_bpb:{ppm_mix_bpb:.6f}") + + if args.ttt_enabled and args.ttt_mode == "lora" and lora_mgr is not None: + log0("Starting legal LoRA-TTT evaluation...") + ttt_val_loss, ttt_val_bpb = eval_val_sliding_lora_ttt( + args=args, + base_model=base_model, + lora_mgr=lora_mgr, + rank=rank, + world_size=world_size, + device=device, + val_tokens=val_tokens, + base_bytes_lut=base_bytes_lut, + has_leading_space_lut=has_leading_space_lut, + is_boundary_token_lut=is_boundary_token_lut, + log0=log0, + ) + log0(f"lora_ttt_final val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f}") + log0(f"lora_ttt_final_exact val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +def eval_only_main() -> None: + """ + Eval-only mode: load a checkpoint, report int8+zlib model size, and run + the standard val evaluation — no training. + + Usage: + EVAL_ONLY=1 CHECKPOINT=path/to/final_model.pt OUTPUT_DIR=logs/eval \ + python mytrain_gpt_v6.py + + The checkpoint must be a raw state_dict saved with torch.save(). + All model-architecture env vars must match the original training run. 
+ """ + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + + rank = int(os.environ.get("RANK", "0")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + outdir = args.output_dir or "logs/eval" + os.makedirs(outdir, exist_ok=True) + logfile = os.path.join(outdir, f"eval_{time.strftime('%Y%m%d_%H%M%S')}.txt") + + def log0(msg: str, console: bool = True): + if not master_process: + return + if console: + print(msg) + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + checkpoint_path = os.environ.get("CHECKPOINT", "") + if not checkpoint_path: + raise ValueError("EVAL_ONLY=1 requires CHECKPOINT= to be set") + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + log0(f"eval_only checkpoint:{checkpoint_path}") + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"VOCAB_SIZE={args.vocab_size} != tokenizer vocab_size={int(sp.vocab_size())}") + + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + token_piece_bytes, sp_has_leading_space, sp_is_boundary_token = build_sentencepiece_byte_tables( + sp, args.vocab_size + ) + + base_model = GPT(args, master_process=master_process).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + + state = torch.load(checkpoint_path, map_location="cpu") + # Support both raw state_dict and {"model": state_dict} wrapping. 
+ if "model" in state and isinstance(state["model"], dict): + state = state["model"] + base_model.load_state_dict(state, strict=True) + log0(f"loaded checkpoint: {checkpoint_path}") + + total_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{total_params:,}") + + # ---------- size report ---------- + quant_file_bytes, code_bytes, stats = compute_submission_size(base_model.state_dict(), code) + limit_bytes = 16 * 1024 * 1024 + total_bytes = quant_file_bytes + code_bytes + log0( + f"submission_size int8+zlib:{quant_file_bytes} bytes " + f"code:{code_bytes} bytes total:{total_bytes} bytes " + f"limit:16MB({limit_bytes}) " + f"{'FITS' if total_bytes <= limit_bytes else 'EXCEEDS_LIMIT'}" + ) + log0(f" params:{stats['param_count']:,} " + f"float_tensors:{stats['num_float_tensors']} " + f"baseline_bytes:{stats['baseline_tensor_bytes']} " + f"int8_payload:{stats['int8_payload_bytes']} " + f"compression:{stats['baseline_tensor_bytes']/max(stats['int8_payload_bytes'],1):.2f}x") + + # ---------- val eval (fp weights) ---------- + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", "1")) + t0 = time.perf_counter() + val_loss, val_bpb = eval_val( + args, compiled_model, rank, 1, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"eval_fp val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-t0):.0f}ms" + ) + log0(f"eval_fp_exact val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f}") + + # ---------- int8 roundtrip eval ---------- + quant_obj, _ = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + quant_state = torch.load(io.BytesIO(zlib.decompress(blob)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t0 = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_model, rank, 1, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"eval_int8_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-t0):.0f}ms" + ) + log0(f"eval_int8_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if args.ppm_enabled: + ppm_bpb = eval_val_with_ppm_mixture( + args=args, model=base_model, rank=rank, world_size=1, + device=device, val_tokens=val_tokens, + token_piece_bytes=token_piece_bytes, + has_leading_space=sp_has_leading_space, + is_boundary_token=sp_is_boundary_token, + log0=log0, + ) + log0(f"ppm_mix_bpb:{ppm_bpb:.6f}") + + log0(f"eval_only done logfile:{logfile}") + + +if __name__ == "__main__": + if os.environ.get("EVAL_ONLY", "0") == "1": + eval_only_main() + else: + main() + +""" +python launchv3.py sweep_ppm_mixture_v1.json \ + --train-script mytrain_gpt_v6.py \ + --output output/run_sweep_ppm_mixture_v1 \ + --stop-mode steps \ + --max-steps 3000 \ + --no-analysis +""" diff --git a/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/README.md b/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/README.md new file mode 100644 index 0000000000..0b2fb77122 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/README.md @@ -0,0 +1,77 @@ +# SP1024 + Shared-V(last3) 3-seed non-record submission + 
+This is a stable non-record 16MB submission based on the official SP1024 tokenizer and a compact transformer with structured skip fusion. + +## Summary + +This submission uses: + +- official `fineweb_1024_bpe.model` +- standard FineWeb SP1024 dataset +- structured skip fusion (`BIFPN2_MODE=1`) +- XSA on the last 4 layers +- 2-gram scaffold with fade-out +- shared V across the last 3 layers + +This submission is intended as a stable, rule-compliant baseline submission rather than a leaderboard-top attempt. + +## Representative run + +Representative seed: **2027** + +Representative exact roundtrip BPB: **1.27717259** + +Submission size: **15973626 bytes** + +## 3-seed results + +| seed | last_val_bpb | roundtrip_exact_val_bpb | submission_bytes | +|------|--------------|-------------------------|------------------| +| 1337 | 1.2791 | 1.28079096 | 15972114 | +| 2027 | 1.2755 | 1.27717259 | 15973626 | +| 3407 | 1.2779 | 1.27952108 | 15975453 | + +3-seed mean exact roundtrip BPB: **1.27916154** + +## Files + +- `submission.json`: metadata for this submission +- `train.log`: representative training log +- `train_gpt.py`: training script snapshot used for this submission +- `config.json`: resolved config for the representative run +- `seed_runs.csv`: all 3 seed results +- `requirements.txt`: minimal environment dependencies + +## Main configuration + +Key settings: + +- tokenizer: SP1024 +- `BIFPN2_MODE=1` +- `XSA_ENABLED=1` +- `XSA_LAST_N_LAYERS=4` +- `NGRAM_MAX_N=2` +- `NGRAM_FADE_ENABLE=1` +- `CROSS_LAYER_KV_SHARING_ENABLED=1` +- `CROSS_LAYER_KV_SHARE_V=1` +- `CROSS_LAYER_KV_PAIRWISE=0` +- `CROSS_LAYER_KV_PARTIAL_HEAD=0` + +## Notes + +- This submission does **not** modify the tokenizer or dataset. +- This is a reproducibility-focused non-record submission under the 16MB artifact limit. +- The representative run uses seed 2027 because it was the best run among the 3 submission seeds. 
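The 3-seed mean quoted above can be recomputed directly from `seed_runs.csv`; a minimal sketch (it assumes only the CSV columns shipped in this folder):

```python
# Recompute the 3-seed mean of the exact roundtrip BPB values.
import csv

with open("seed_runs.csv", newline="") as f:
    rows = list(csv.DictReader(f))

vals = [float(r["roundtrip_exact_val_bpb"]) for r in rows]
print(f"3-seed mean exact roundtrip BPB: {sum(vals) / len(vals):.8f}")  # 1.27916154
```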
+ +## Reproduction + +Typical command pattern: + +```bash +python launchv3.py config_submission_sharev3_3seed.json \ + --train-script mytrain_gpt_v2_1.py \ + --output output/submission_sharev3_3seed \ + --stop-mode steps \ + --max-steps 3000 \ + --only submission_seed2027 +``` diff --git a/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/config.json b/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/config.json new file mode 100644 index 0000000000..ef3a92ab95 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/config.json @@ -0,0 +1,121 @@ +{ + "_comment_TRACK": "Stable non-record submission candidate under 16MB, using official SP1024 tokenizer and no tokenizer changes", + "_comment_DATA": "Official SP1024 data/tokenizer", + "DATA_PATH": "./data/datasets/fineweb10B_sp1024", + "TOKENIZER_PATH": "./data/tokenizers/fineweb_1024_bpe.model", + "VOCAB_SIZE": 1024, + "_comment_CORE": "Core model shape", + "NUM_LAYERS": 9, + "MODEL_DIM": 512, + "NUM_HEADS": 8, + "NUM_KV_HEADS": 4, + "MLP_MULT": 2, + "TIE_EMBEDDINGS": 1, + "ROPE_BASE": 10000.0, + "LOGIT_SOFTCAP": 30.0, + "QK_GAIN_INIT": 1.5, + "_comment_TRAIN": "Train schedule", + "GRAD_ACCUM_STEPS": 4, + "TRAIN_BATCH_TOKENS": 524288, + "TRAIN_SEQ_LEN": 1024, + "ITERATIONS": 20000, + "WARMUP_STEPS": 20, + "WARMDOWN_ITERS": 1200, + "STOP_MODE": "steps", + "MAX_TRAIN_STEPS": 3000, + "MAX_WALLCLOCK_SECONDS": 3600.0, + "_comment_OPTIM": "Optimizer", + "MATRIX_LR": 0.04, + "SCALAR_LR": 0.04, + "EMBED_LR": 0.6, + "HEAD_LR": 0.008, + "TIED_EMBED_LR": 0.05, + "TIED_EMBED_INIT_STD": 0.005, + "MUON_MOMENTUM": 0.95, + "MUON_BACKEND_STEPS": 5, + "MUON_MOMENTUM_WARMUP_START": 0.85, + "MUON_MOMENTUM_WARMUP_STEPS": 500, + "BETA1": 0.9, + "BETA2": 0.95, + "ADAM_EPS": 1e-08, + "GRAD_CLIP_NORM": 0.0, + "_comment_SKIP": "Best stable under-size stack", + "FDA_MODE": 0, + "BIFPN_MODE": 0, + "BIFPN2_MODE": 1, + "BIFPN_GROUP_COUNT": 8, + "BIFPN_BAND_WIDTH": 1, + "BIFPN_NORM_EPS": 0.0001, + "BIFPN_INIT_MAIN": 1.0, + "BIFPN_INIT_NEIGHBOR": 0.15, + "BIFPN_INIT_FAR": 0.0, + "_comment_STAB": "Stability toggles", + "SCALEDLM_HEAD": 1, + "SMEAR_MODE": 0, + "SMEAR_WINDOW": 4, + "SMEAR_GATE": 0, + "ROPE_DIMS": -1, + "LEARNABLE_ROPE": 0, + "LN_SCALE": 1, + "LEARNABLE_LN_SCALE": 0, + "AFFINE_NORM": 0, + "_comment_XSA": "Keep XSA on last 4 layers", + "XSA_ENABLED": 1, + "XSA_LAST_N_LAYERS": 4, + "XSA_EPS": 1e-06, + "_comment_VALUE_PATH": "Use plain shared V only; this was the under-16MB stable candidate", + "V_SKIP_ENABLED": 0, + "V_SKIP_LAST_N_LAYERS": 4, + "V_SKIP_MODE": "scalar", + "V_SKIP_GROUP_COUNT": 8, + "CROSS_LAYER_V_ENABLED": 0, + "CROSS_LAYER_V_LAST_N_LAYERS": 4, + "CROSS_LAYER_V_MODE": "residual", + "CROSS_LAYER_V_GROUP_COUNT": 4, + "_comment_MEMORY_PATH": "Share V across later layers, no K sharing", + "CROSS_LAYER_KV_SHARING_ENABLED": 1, + "CROSS_LAYER_KV_LAST_N_LAYERS": 3, + "CROSS_LAYER_KV_SHARE_K": 0, + "CROSS_LAYER_KV_SHARE_V": 1, + "CROSS_LAYER_KV_PAIRWISE": 0, + "CROSS_LAYER_KV_PARTIAL_HEAD": 0, + "CROSS_LAYER_KV_PARTIAL_HEAD_COUNT": 2, + "_comment_PLE": "Disabled for this stable submission", + "PLE_ENABLED": 0, + "PLE_TEMPORAL_CONV": 0, + "PLE_DIM": 32, + "PLE_MODE": "post_attn", + "PLE_TOKEN_SCALE_INIT": 1.0, + "PLE_CTX_SCALE_INIT": 1.0, + "PLE_RESID_SCALE_INIT": 0.01, + "_comment_MTP": "Disabled", + "MTP_NUM_HEADS": 0, + "MTP_LOSS_WEIGHT": 0.2, + "MTPHEAD_MLPMODE": 0, + "_comment_NGRAM": "Keep 2-gram scaffold + fade-out", + "NGRAM_VOCAB_SIZE": 2048, + "NGRAM_DIM": 128, + "NGRAM_MAX_N": 2, +
"NGRAM_FADE_ENABLE": 1, + "NGRAM_FADE_START_FRAC": 0.15, + "NGRAM_FADE_END_FRAC": 0.45, + "NGRAM_FADE_MIN_SCALE": 0.0, + "_comment_EMA_QAT": "Keep EMA and conservative late QAT", + "EMA_ENABLED": 1, + "EMA_DECAY": 0.997, + "DYNAMIC_CLIP_PERCENTILES": "100.0,99.9999,99.9995,99.995,99.9", + "LATE_QAT_RATIO": 0.15, + "_comment_EVAL": "Submission run should use non-sliding eval for direct comparability", + "VAL_LOSS_EVERY": 1000, + "VAL_BATCH_SIZE": 524288, + "EVAL_USE_SLIDING_WINDOW": 0, + "EVAL_STRIDE": 1024, + "EVAL_BATCH_SEQS": 16, + "_comment_LOGGING": "Telemetry/logging", + "TELEMETRY_EVERY": 50, + "TRAIN_LOG_EVERY": 200, + "PROFILE_RUN": 0, + "PROFILE_WARMUP_STEPS": 5, + "PROFILE_ACTIVE_STEPS": 10, + "SEED": 2027 +} \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/requirements.txt b/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/requirements.txt new file mode 100644 index 0000000000..911b0e52f0 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/requirements.txt @@ -0,0 +1,10 @@ +numpy +tqdm +torch +huggingface-hub +kernels +setuptools +typing-extensions==4.15.0 +datasets +tiktoken +sentencepiece \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/seed_runs.csv b/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/seed_runs.csv new file mode 100644 index 0000000000..db4903a3cf --- /dev/null +++ b/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/seed_runs.csv @@ -0,0 +1,4 @@ +experiment,seed,last_val_bpb,roundtrip_val_bpb,roundtrip_exact_val_bpb,submission_bytes,stopped_step,output_dir +submission_seed1337,1337,1.2791,1.2808,1.28079096,15972114,3000,output/submission_sharev3_3seed/submission_seed1337_20260418_100202 +submission_seed2027,2027,1.2755,1.2772,1.27717259,15973626,3000,output/submission_sharev3_3seed/submission_seed2027_20260418_103507 +submission_seed3407,3407,1.2779,1.2795,1.27952108,15975453,3000,output/submission_sharev3_3seed/submission_seed3407_20260418_110812 diff --git a/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/submission.json b/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/submission.json new file mode 100644 index 0000000000..593c1e3b09 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/submission.json @@ -0,0 +1,17 @@ +{ + "title": "SP1024 + Shared-V(last3) + BIFPN2 + XSA4 + NGram Fade", + "author": "Kaikai Liu", + "github_id": "lkk688", + "track": "non-record-16mb", + "description": "Stable SP1024 non-record submission under the 16MB artifact cap.", + "val_bpb": 1.27717259, + "artifact_bytes": 15973626, + "representative_seed": 2027, + "seeds": [ + 1337, + 2027, + 3407 + ], + "tokenizer": "official fineweb_1024_bpe.model", + "dataset": "official fineweb10B_sp1024" +} diff --git a/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/train.log b/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/train.log new file mode 100644 index 0000000000..9dcace2b19 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/train.log @@ -0,0 +1,76 @@ +output/submission_sharev3_3seed/submission_seed2027_20260418_103507/20260418_103511.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards 
pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +Architecture: Discrete N-Gram Hash (Max N=2) +Architecture: StructuredGroupSignedBiFPN (groups=8, band=1) +model_params:17390313 +world_size:1 grad_accum_steps:4 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:3600.000 +seed:2027 +Architecture Skip Mode: Symmetric U-Net +Enhancement: Discrete N-Gram Hash (Max N=2) +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +EMA Enabled: decay=0.997 +Scheduled Late QAT to start at step 2550 (last 15.0%) +step:0/3000 val_loss:6.9310 val_bpb:4.1049 train_time:5ms step_avg:4.66ms +step:1/3000 train_loss:6.9310 train_time:4328ms step_avg:4328.03ms +step:2/3000 train_loss:6.7809 train_time:7824ms step_avg:3912.03ms +step:3/3000 train_loss:6.3509 train_time:8434ms step_avg:2811.31ms +step:4/3000 train_loss:6.0286 train_time:9048ms step_avg:2262.10ms +step:5/3000 train_loss:5.8585 train_time:9663ms step_avg:1932.65ms +step:6/3000 train_loss:5.7350 train_time:10276ms step_avg:1712.72ms +step:7/3000 train_loss:5.6178 train_time:10890ms step_avg:1555.74ms +step:8/3000 train_loss:5.5590 train_time:11507ms step_avg:1438.42ms +step:9/3000 train_loss:5.4568 train_time:12123ms step_avg:1346.98ms +step:10/3000 train_loss:5.3681 train_time:12735ms step_avg:1273.49ms +step:200/3000 train_loss:2.7164 train_time:130023ms step_avg:650.12ms +step:400/3000 train_loss:2.3737 train_time:253575ms step_avg:633.94ms +step:600/3000 train_loss:2.4822 train_time:377129ms step_avg:628.55ms +step:800/3000 train_loss:2.3391 train_time:500560ms step_avg:625.70ms +step:1000/3000 train_loss:2.3517 train_time:624011ms step_avg:624.01ms +step:1000/3000 val_loss:2.3286 val_bpb:1.3791 train_time:624012ms step_avg:624.01ms +step:1200/3000 train_loss:2.2892 train_time:747509ms step_avg:622.92ms +step:1400/3000 train_loss:2.3483 train_time:870940ms step_avg:622.10ms +step:1600/3000 train_loss:2.2245 train_time:994295ms step_avg:621.43ms +step:1800/3000 train_loss:2.2709 train_time:1117660ms step_avg:620.92ms +step:2000/3000 train_loss:2.2152 train_time:1240992ms step_avg:620.50ms +step:2000/3000 val_loss:2.2320 val_bpb:1.3219 train_time:1240994ms step_avg:620.50ms +step:2200/3000 train_loss:2.1446 train_time:1364383ms step_avg:620.17ms +step:2400/3000 train_loss:2.1720 train_time:1487687ms step_avg:619.87ms +[Step 2550] Activating Late QAT — enabling branchless STE quantization. +step:2600/3000 train_loss:2.2332 train_time:1610963ms step_avg:619.60ms +step:2800/3000 train_loss:2.1837 train_time:1734230ms step_avg:619.37ms +step:3000/3000 train_loss:2.1042 train_time:1857569ms step_avg:619.19ms +step:3000/3000 val_loss:2.1537 val_bpb:1.2755 train_time:1857570ms step_avg:619.19ms +peak memory allocated: 22182 MiB reserved: 24640 MiB +Applying EMA weights for final evaluation... 
+Serialized model: 67895209 bytes +Code size: 126855 bytes +Total submission size: 68022064 bytes +Serialized model int8+zlib: 15846771 bytes (payload:17577610 raw_torch:17627197 payload_ratio:3.86x) +Total submission size int8+zlib: 15973626 bytes +final_int8_zlib_roundtrip val_loss:2.1565 val_bpb:1.2772 eval_time:18650ms +final_int8_zlib_roundtrip_exact val_loss:2.15645243 val_bpb:1.27717259 diff --git a/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/train_gpt.py b/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/train_gpt.py new file mode 100644 index 0000000000..e6f1d7cec7 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-18_SP1024_ShareVLast3_3Seed/train_gpt.py @@ -0,0 +1,2654 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +from xml.parsers.expat import model +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.profiler +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +import json # === NEW: For telemetry logging === + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + + # NEW: Sliding window validation control + eval_use_sliding_window = bool(int(os.environ.get("EVAL_USE_SLIDING_WINDOW", "0"))) + eval_stride = int(os.environ.get("EVAL_STRIDE", "128")) #64, 128, 1024 + #With TRAIN_SEQ_LEN=1024, EVAL_STRIDE=1024 means no real overlap. That mostly defeats the purpose of sliding eval. 
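    # Rough cost note: with TRAIN_SEQ_LEN=1024 and EVAL_STRIDE=128 each window still runs a
    # full 1024-token forward pass but only ~128 positions are newly scored, so sliding eval
    # costs roughly seq_len/stride (~8x here) more forward compute than non-overlapping eval.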
+ eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", "16")) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # === NEW: TELEMETRY AND ABLATION CONTROL === + # Set FDA_MODE=1 to use Forward Dense Addition skips instead of U-Net skips. + fda_mode = bool(int(os.environ.get("FDA_MODE", "0"))) + # Set to > 0 (e.g., 50) to log internal states without hurting speed. 
+ output_dir = os.environ.get("OUTPUT_DIR", "") + telemetry_every = int(os.environ.get("TELEMETRY_EVERY", "0")) + telemetry_file = os.environ.get("TELEMETRY_FILE", os.path.join(output_dir, "telemetry.jsonl") if output_dir else "logs/telemetry.jsonl") + + # === NEW: XSA & Profiling Suite === + xsa_enabled = bool(int(os.environ.get("XSA_ENABLED", "0"))) + xsa_last_n_layers = int(os.environ.get("XSA_LAST_N_LAYERS", "0")) + xsa_eps = float(os.environ.get("XSA_EPS", "1e-6")) + + profile_run = bool(int(os.environ.get("PROFILE_RUN", "0"))) + profile_warmup_steps = int(os.environ.get("PROFILE_WARMUP_STEPS", "5")) + profile_active_steps = int(os.environ.get("PROFILE_ACTIVE_STEPS", "10")) + profile_output_dir = os.environ.get("PROFILE_OUTPUT_DIR", "output/prof_base") + + # === NEW: MTP Control === + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", "0")) # Set to e.g., 2 to predict t+2 and t+3 + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", "0.2")) + mtphead_mlpmode = bool(int(os.environ.get("MTPHEAD_MLPMODE", "0"))) + + # === 新增:N-Gram / SMEAR 控制 === + ngram_vocab_size = int(os.environ.get("NGRAM_VOCAB_SIZE", "2048")) + ngram_dim = int(os.environ.get("NGRAM_DIM", "128")) + ngram_max_n = int(os.environ.get("NGRAM_MAX_N", "4")) # 2=Bigram, 3=Trigram, 4=4-gram + smear_mode = bool(int(os.environ.get("SMEAR_MODE", "0"))) + smear_window = int(os.environ.get("SMEAR_WINDOW", "4")) + + # === N-Gram fade-out schedule === + ngram_fade_enable = bool(int(os.environ.get("NGRAM_FADE_ENABLE", "0"))) + ngram_fade_start_frac = float(os.environ.get("NGRAM_FADE_START_FRAC", "0.15")) + ngram_fade_end_frac = float(os.environ.get("NGRAM_FADE_END_FRAC", "0.45")) + ngram_fade_min_scale = float(os.environ.get("NGRAM_FADE_MIN_SCALE", "0.0")) + + # === 新增: 架构微调标志 === + rope_dims = int(os.environ.get("ROPE_DIMS", "-1")) # -1 表示由 head_dim 决定 + learnable_rope = bool(int(os.environ.get("LEARNABLE_ROPE", "0"))) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + learnable_ln_scale = bool(int(os.environ.get("LEARNABLE_LN_SCALE", "0"))) + scaledlm_head = bool(int(os.environ.get("SCALEDLM_HEAD", "1"))) + bifpn_mode = bool(int(os.environ.get("BIFPN_MODE", "0"))) + affine_norm = bool(int(os.environ.get("AFFINE_NORM", "0"))) + smear_gate = bool(int(os.environ.get("SMEAR_GATE", "0"))) + late_qat_ratio = float(os.environ.get("LATE_QAT_RATIO", "0.15")) + stop_mode = os.environ.get("STOP_MODE", "walltime") # walltime | steps + max_train_steps = int(os.environ.get("MAX_TRAIN_STEPS", "0")) + + # === Structured sparse + group-wise signed BiFPN === + bifpn2_mode = bool(int(os.environ.get("BIFPN2_MODE", "0"))) + bifpn_group_count = int(os.environ.get("BIFPN_GROUP_COUNT", "8")) + bifpn_band_width = int(os.environ.get("BIFPN_BAND_WIDTH", "1")) # 0=only symmetric, 1=neighbor band + bifpn_norm_eps = float(os.environ.get("BIFPN_NORM_EPS", "1e-4")) + bifpn_init_main = float(os.environ.get("BIFPN_INIT_MAIN", "1.0")) + bifpn_init_neighbor = float(os.environ.get("BIFPN_INIT_NEIGHBOR", "0.15")) + bifpn_init_far = float(os.environ.get("BIFPN_INIT_FAR", "0.0")) + + # === Value path / memory path research flags === + v_skip_enabled = bool(int(os.environ.get("V_SKIP_ENABLED", "0"))) + v_skip_last_n_layers = int(os.environ.get("V_SKIP_LAST_N_LAYERS", "0")) + v_skip_mode = os.environ.get("V_SKIP_MODE", "scalar") # scalar|group + v_skip_group_count = int(os.environ.get("V_SKIP_GROUP_COUNT", "8")) + + cross_layer_v_enabled = bool(int(os.environ.get("CROSS_LAYER_V_ENABLED", "0"))) + cross_layer_v_last_n_layers = 
int(os.environ.get("CROSS_LAYER_V_LAST_N_LAYERS", "0")) + cross_layer_v_mode = os.environ.get("CROSS_LAYER_V_MODE", "residual") # residual|blend + cross_layer_v_group_count = int(os.environ.get("CROSS_LAYER_V_GROUP_COUNT", "8")) + + cross_layer_kv_sharing_enabled = bool(int(os.environ.get("CROSS_LAYER_KV_SHARING_ENABLED", "0"))) + cross_layer_kv_last_n_layers = int(os.environ.get("CROSS_LAYER_KV_LAST_N_LAYERS", "0")) + cross_layer_kv_share_k = bool(int(os.environ.get("CROSS_LAYER_KV_SHARE_K", "1"))) + cross_layer_kv_share_v = bool(int(os.environ.get("CROSS_LAYER_KV_SHARE_V", "1"))) + cross_layer_kv_pairwise = bool(int(os.environ.get("CROSS_LAYER_KV_PAIRWISE", "0"))) + cross_layer_kv_partial_head = bool(int(os.environ.get("CROSS_LAYER_KV_PARTIAL_HEAD", "0"))) + cross_layer_kv_partial_head_count = int(os.environ.get("CROSS_LAYER_KV_PARTIAL_HEAD_COUNT", "2")) + + +@torch.compile(dynamic=False, fullgraph=True) +def update_ema_fused(ema_tensors: list[Tensor], model_tensors: list[Tensor], decay: float): + for e, m in zip(ema_tensors, model_tensors): + e.mul_(decay).add_(m.float(), alpha=1.0 - decay) +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def tokens_to_bytes_count( + xb: Tensor, + yb: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> Tensor: + prev_ids = xb.reshape(-1) + tgt_ids = yb.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + return token_bytes.sum() + + +@torch.no_grad() +def eval_val_sliding( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """ + Sliding-window validation. + + Each scored token is counted exactly once. 
+ For each overlapping window, only the rightmost `eval_stride` target tokens + are scored, so tokens are evaluated with near-max left context. + """ + + model.eval() + + seq_len = args.train_seq_len + stride = args.eval_stride + batch_seqs = args.eval_batch_seqs + + if stride <= 0: + raise ValueError(f"EVAL_STRIDE must be > 0, got {stride}") + if stride > seq_len: + raise ValueError( + f"EVAL_STRIDE must be <= TRAIN_SEQ_LEN, got stride={stride}, seq_len={seq_len}" + ) + if batch_seqs <= 0: + raise ValueError(f"EVAL_BATCH_SEQS must be > 0, got {batch_seqs}") + + # val_tokens has shape [N+1], where x = tokens[:-1], y = tokens[1:] + total_loss_sum = torch.zeros(1, device=device, dtype=torch.float64) + total_token_count = torch.zeros(1, device=device, dtype=torch.float64) + total_byte_count = torch.zeros(1, device=device, dtype=torch.float64) + + max_start = val_tokens.numel() - 1 - seq_len + if max_start < 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + + starts = list(range(0, max_start + 1, stride)) + if starts[-1] != max_start: + starts.append(max_start) + + # shard starts across ranks + starts = starts[rank::world_size] + + def _score_batch(xb: Tensor, yb: Tensor) -> tuple[Tensor, Tensor, Tensor]: + # xb: [B, T], yb: [B, T] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + _, loss_tokens = model(xb, yb, reduction="none") + # loss_tokens expected shape [B, T] + score_loss = loss_tokens[:, -stride:] # only score rightmost stride positions + + # token count + token_count = torch.tensor(score_loss.numel(), device=device, dtype=torch.float64) + + # byte count + scored_y = yb[:, -stride:] + scored_x = xb[:, -stride:] # Required to find inter-token boundaries + byte_count = tokens_to_bytes_count( + scored_x, + scored_y, + base_bytes_lut=base_bytes_lut, + has_leading_space_lut=has_leading_space_lut, + is_boundary_token_lut=is_boundary_token_lut, + ).to(torch.float64) + + loss_sum = score_loss.sum(dtype=torch.float64) + return loss_sum, token_count, byte_count + + batch_x = [] + batch_y = [] + + for start in starts: + chunk = val_tokens[start : start + seq_len + 1].to(device=device, dtype=torch.int64, non_blocking=True) + x = chunk[:-1] + y = chunk[1:] + + batch_x.append(x) + batch_y.append(y) + + if len(batch_x) == batch_seqs: + xb = torch.stack(batch_x) + yb = torch.stack(batch_y) + + loss_sum, token_count, byte_count = _score_batch(xb, yb) + total_loss_sum += loss_sum + total_token_count += token_count + total_byte_count += byte_count + + batch_x.clear() + batch_y.clear() + + if batch_x: + xb = torch.stack(batch_x) + yb = torch.stack(batch_y) + + loss_sum, token_count, byte_count = _score_batch(xb, yb) + total_loss_sum += loss_sum + total_token_count += token_count + total_byte_count += byte_count + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(total_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(total_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(total_byte_count, op=dist.ReduceOp.SUM) + + val_loss = (total_loss_sum / total_token_count).item() + val_bpb = (total_loss_sum / (math.log(2.0) * total_byte_count)).item() + + model.train() + return val_loss, val_bpb + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: 
token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
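+# A toy per-row example of the int8 round trip implemented below (illustrative numbers only):
+#   row w = [0.020, -0.310, 0.070]  ->  clip_abs ~= 0.310, scale = 0.310 / 127 ~= 0.00244
+#   q = round(w / scale) = [8, -127, 29]          (stored as int8, one fp16 scale per row)
+#   dequant = q * scale ~= [0.0195, -0.3100, 0.0708], max error <= scale / 2 per unclipped element.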
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +# def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: +# t32 = t.float() +# if t32.ndim == 2: +# # Matrices get one scale per row, which usually tracks output-channel +# # ranges much better than a single tensor-wide scale. +# clip_abs = ( +# torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) +# if t32.numel() +# else torch.empty((t32.shape[0],), dtype=torch.float32) +# ) +# clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) +# scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) +# q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() +# return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +# # Vectors / scalars use a simpler per-tensor scale. 
+# clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
+# scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
+# q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
+# return q, scale
+
+# =========================================================================
+# The original constants are kept unchanged as the default fallback.
+INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
+INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
+INT8_PER_ROW_SCALE_DTYPE = torch.float16
+INT8_CLIP_PERCENTILE = 99.99984
+INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0
+
+# === New: read the list of clipping percentiles to scan from an environment variable ===
+# e.g.: "100.0,99.9999,99.9995,99.999,99.99"
+DYNAMIC_CLIP_Q_LIST = [
+    float(p) / 100.0
+    for p in os.environ.get("DYNAMIC_CLIP_PERCENTILES", "100.0").split(",")
+    if p.strip()
+]
+# =========================================================================
+
+def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
+    t32 = t.float()
+
+    if t32.ndim == 2:
+        # === Dynamic clipping-percentile scan (GPTQ-lite style) ===
+        best_q = None
+        best_scale = None
+        best_mse = float('inf')
+
+        # Try every candidate clipping percentile.
+        for q_percentile in DYNAMIC_CLIP_Q_LIST:
+            if q_percentile >= 1.0:
+                # 100% means no clipping: take the per-row absolute maximum directly.
+                clip_abs = t32.abs().max(dim=1).values
+            else:
+                clip_abs = (
+                    torch.quantile(t32.abs(), q_percentile, dim=1)
+                    if t32.numel()
+                    else torch.empty((t32.shape[0],), dtype=torch.float32, device=t32.device)
+                )
+
+            clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
+            scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
+            q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8)
+
+            # Dequantize and measure the reconstruction MSE.
+            dequantized = q.float() * scale[:, None]
+            mse = torch.nn.functional.mse_loss(dequantized, t32).item()
+
+            # Keep the first candidate (the default) or any later candidate with a lower MSE.
+            if best_q is None or mse < best_mse:
+                best_mse = mse
+                best_q = q
+                best_scale = scale
+
+        # contiguous() is required so serialization behaves correctly.
+        return best_q.contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
+
+    # Vectors / scalars are small, so the expensive scan is not worth it; keep the original logic.
+    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
+    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
+    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
+    return q, scale
+
+def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
+    # Single supported clean-script export format:
+    # - per-row int8 for 2D float tensors
+    # - per-tensor int8 for other float tensors
+    # - exact passthrough for non-floats
+    # - passthrough for small float tensors, stored as fp16 to save bytes
+    quantized: dict[str, Tensor] = {}
+    scales: dict[str, Tensor] = {}
+    dtypes: dict[str, str] = {}
+    passthrough: dict[str, Tensor] = {}
+    passthrough_orig_dtypes: dict[str, str] = {}
+    qmeta: dict[str, dict[str, object]] = {}
+    stats = dict.fromkeys(
+        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
+        0,
+    )
+
+    for name, tensor in state_dict.items():
+        t = tensor.detach().to("cpu").contiguous()
+        stats["param_count"] += int(t.numel())
+        stats["num_tensors"] += 1
+        stats["baseline_tensor_bytes"] += tensor_nbytes(t)
+
+        if not t.is_floating_point():
+            stats["num_nonfloat_tensors"] += 1
+            passthrough[name] = t
+            stats["int8_payload_bytes"] += tensor_nbytes(t)
+            continue
+
+        # Small float tensors are cheap enough to keep directly.
We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
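+    # Illustrative example (not taken from the run): with seq_len=4 and a per-rank span of 5 tokens,
+    #   chunk slice = [t0 t1 t2 t3 t4]  ->  x = [t0 t1 t2 t3], y = [t1 t2 t3 t4]
+    # so y[i] is always the token that follows x[i] in the stream.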
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +def _expand_group_gates(g: Tensor, total_dim: int) -> Tensor: + """ + g: [G] + return: [D], where D % G == 0 + """ + if total_dim % g.numel() != 0: + raise ValueError(f"total_dim ({total_dim}) must be divisible by num_groups ({g.numel()})") + group_dim = total_dim // g.numel() + return g.repeat_interleave(group_dim) + + +def apply_v_skip( + y: Tensor, + v: Tensor, + gate: Tensor, + mode: str = "scalar", + num_heads: int | None = None, + num_kv_heads: int | None = None, +) -> Tensor: + """ + y: [B, H, T, D] + v: [B, Hkv, T, D] + gate: + scalar mode: shape [1] or [] + group mode: shape [G] + Returns: + y + gated(V path) with GQA-aware expansion + """ + b, h, t, d = y.shape + hkv = v.shape[1] + + if h == hkv: + v_exp = v + else: + if num_heads is None or num_kv_heads is None: + raise ValueError("num_heads and num_kv_heads required for GQA") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + group_size = num_heads // num_kv_heads + v_exp = v.unsqueeze(2).expand(b, hkv, group_size, t, d).reshape(b, h, t, d) + + if mode == "scalar": + g = torch.sigmoid(gate.to(dtype=y.dtype)).reshape(1, 1, 1, 1) + return y + g * v_exp + elif mode == "group": + g = torch.sigmoid(gate.to(dtype=y.dtype)) + g = _expand_group_gates(g, d).view(1, 1, 1, d) + return y + g * v_exp + else: + raise ValueError(f"Unknown V_SKIP_MODE: {mode}") + + +def mix_cross_layer_v( + v_cur: Tensor, + v_prev: Tensor, + gate: Tensor, + mode: str = "residual", + group_mode: str = "scalar", +) -> Tensor: + """ + v_cur, v_prev: [B, Hkv, T, D] + mode: + residual: v_cur + g * v_prev + blend: (1-g) * v_cur + g * v_prev + group_mode: + scalar or group + """ + d = v_cur.shape[-1] + + if group_mode == "scalar": + g = torch.sigmoid(gate.to(dtype=v_cur.dtype)).reshape(1, 1, 1, 1) + elif group_mode == "group": + g = torch.sigmoid(gate.to(dtype=v_cur.dtype)) + g = _expand_group_gates(g, d).view(1, 1, 1, d) + else: + raise ValueError(f"Unknown cross-layer V group mode: {group_mode}") + + if mode == "residual": + return v_cur + g * v_prev + elif mode == "blend": + return (1.0 - g) * v_cur + g * v_prev + else: + raise ValueError(f"Unknown CROSS_LAYER_V_MODE: {mode}") + + +def apply_partial_head_sharing( + cur: Tensor, + shared: Tensor, + share_head_count: int, +) -> Tensor: + """ + cur, shared: [B, H, T, D] or [B, Hkv, T, D] + Replace the first N heads with shared heads. 
+ """ + h = cur.shape[1] + n = min(max(share_head_count, 0), h) + if n == 0: + return cur + out = cur.clone() + out[:, :n] = shared[:, :n] + return out + +class RMSNorm(nn.Module): + def __init__(self, dim: int | None = None, eps: float | None = None, affine: bool = False): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) if (affine and dim is not None) else None + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) if self.weight is not None else None + return F.rms_norm(x, (x.size(-1),), weight=w, eps=self.eps) + + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, bias: bool = False): + super().__init__(in_features, out_features, bias=bias) + # Branchless QAT gate: 0.0 = no-op, 1.0 = STE fake-quantize. + # A float buffer avoids any Python branch, so torch.compile(fullgraph=True) + # always traces the same single graph regardless of QAT phase. + # To activate late QAT: call module.qat_alpha.fill_(1.0) — no recompile needed. + self.register_buffer("qat_alpha", torch.tensor(0.0, dtype=torch.float32), persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + w = self.weight + # STE: forward uses w + alpha*(w_quant - w), backward flows through w. + # When qat_alpha=0 the delta is zeroed out entirely (pure fp pass-through). + w_max = w.detach().abs().amax(dim=1, keepdim=True) + scale = (w_max / 127.0).clamp_min(1e-7) + w_quant = torch.clamp(torch.round(w / scale), -127, 127) * scale + w = w + (self.qat_alpha * (w_quant - w)).detach() + return F.linear(x, w.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +# def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: +# half = x.size(-1) // 2 +# x1, x2 = x[..., :half], x[..., half:] +# return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + """ + 极速版 Partial RoPE:将多重拼接合并为单次拼接。 + """ + if rope_dims > 0 and rope_dims < x.size(-1): + half = rope_dims // 2 + # 切片全是视图 (View),零内存开销 + x1 = x[..., :half] + x2 = x[..., half:rope_dims] + x_pass = x[..., rope_dims:] + + cos_part = cos[..., :half] + sin_part = sin[..., :half] + + # 将两次 cat 合并为一次,Inductor 能将其编译为单个极速的 Triton Kernel + return torch.cat(( + x1 * cos_part + x2 * sin_part, + x1 * (-sin_part) + x2 * cos_part, + x_pass + ), dim=-1) + + # 全维度旋转 + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def apply_xsa_gqa_efficient( + y: Tensor, + v: Tensor, + num_heads: int, + num_kv_heads: int, + eps: float = 1e-6, +) -> Tensor: + """ + Efficient XSA postprocess for GQA/MHA. + + Args: + y: attention output before final proj, shape [B, H, T, D] + v: value tensor before GQA expansion, shape [B, Hkv, T, D] + num_heads: query heads + num_kv_heads: kv heads + eps: numerical epsilon + + Returns: + y_xsa: same shape as y, with self-value-direction component removed + + Notes: + - Keeps standard attention path intact. + - Avoids repeat_interleave for GQA. + - Works for MHA too when num_heads == num_kv_heads. + """ + if num_heads == num_kv_heads: + # Standard MHA case + # Normalize self-value direction + vn = v / (v.norm(dim=-1, keepdim=True) + eps) # [B, H, T, D] + proj = (y * vn).sum(dim=-1, keepdim=True) # [B, H, T, 1] + return y - proj * vn + + # GQA case + if num_heads % num_kv_heads != 0: + raise ValueError( + f"num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})" + ) + + group_size = num_heads // num_kv_heads + + # y: [B, H, T, D] -> [B, Hkv, group, T, D] + b, h, t, d = y.shape + yg = y.view(b, num_kv_heads, group_size, t, d) + + # v: [B, Hkv, T, D] -> normalize -> [B, Hkv, 1, T, D] + vn = v / (v.norm(dim=-1, keepdim=True) + eps) + vn = vn.unsqueeze(2) + + # Remove projection onto normalized self-value direction + proj = (yg * vn).sum(dim=-1, keepdim=True) # [B, Hkv, group, T, 1] + yg = yg - proj * vn + + return yg.view(b, h, t, d) + +class StructuredGroupSignedBiFPN(nn.Module): + """ + Structured sparse + group-wise signed fusion. 
+ + Features: + - structured sparse connectivity via band mask + - group-wise signed weights instead of single scalar + - normalized by sum of absolute weights per decoder/group + + Shape convention: + skips: list of encoder features, each [B, T, D] + output for one decoder layer: [B, T, D] + """ + def __init__( + self, + num_decoder_layers: int, + num_encoder_layers: int, + model_dim: int, + group_count: int = 8, + band_width: int = 1, + norm_eps: float = 1e-4, + init_main: float = 1.0, + init_neighbor: float = 0.15, + init_far: float = 0.0, + ): + super().__init__() + if model_dim % group_count != 0: + raise ValueError( + f"model_dim ({model_dim}) must be divisible by group_count ({group_count})" + ) + + self.num_decoder_layers = num_decoder_layers + self.num_encoder_layers = num_encoder_layers + self.model_dim = model_dim + self.group_count = group_count + self.group_dim = model_dim // group_count + self.band_width = band_width + self.norm_eps = norm_eps + + # Signed weights per decoder, encoder, group + # shape: [Dec, Enc, G] + w = torch.full( + (num_decoder_layers, num_encoder_layers, group_count), + init_far, + dtype=torch.float32, + ) + + # Structured sparse prior: + # main symmetric connection + neighbor band + for d in range(num_decoder_layers): + sym = num_encoder_layers - 1 - d + for e in range(num_encoder_layers): + dist = abs(e - sym) + if dist == 0: + w[d, e, :] = init_main + elif dist <= band_width: + w[d, e, :] = init_neighbor + + # Keep a binary mask for allowed connections + mask = torch.zeros( + (num_decoder_layers, num_encoder_layers, 1), + dtype=torch.float32, + ) + for d in range(num_decoder_layers): + sym = num_encoder_layers - 1 - d + for e in range(num_encoder_layers): + if abs(e - sym) <= band_width: + mask[d, e, 0] = 1.0 + + self.weights = nn.Parameter(w) + self.register_buffer("mask", mask, persistent=True) + + def forward(self, skips: list[Tensor], decoder_idx: int, x_dtype: torch.dtype) -> Tensor: + """ + skips: list of encoder outputs, len = Enc, each [B,T,D] + decoder_idx: which decoder layer is requesting fusion + returns fused skip feature [B,T,D] + """ + if len(skips) != self.num_encoder_layers: + raise ValueError( + f"Expected {self.num_encoder_layers} skips, got {len(skips)}" + ) + + # [Enc, B, T, D] + stacked = torch.stack(skips, dim=0) + + # reshape feature dim into groups: [Enc, B, T, G, Gd] + enc, b, t, d = stacked.shape + stacked_g = stacked.view(enc, b, t, self.group_count, self.group_dim) + + # signed group-wise weights for this decoder: [Enc, G] + w = self.weights[decoder_idx] * self.mask[decoder_idx] # [Enc, G] + w = w.to(dtype=x_dtype) + + # Normalize by sum of abs weights per group + denom = w.abs().sum(dim=0, keepdim=True).clamp_min(self.norm_eps) # [1, G] + w_norm = w / denom # [Enc, G] + + # weighted sum: [Enc,G] x [Enc,B,T,G,Gd] -> [B,T,G,Gd] + fused = torch.einsum("eg,ebtgd->btgd", w_norm, stacked_g) + + # back to [B,T,D] + fused = fused.reshape(b, t, d) + return fused + + @torch.no_grad() + def export_effective_matrix(self) -> Tensor: + """ + Returns decoder-encoder scalar summary matrix [Dec, Enc] + by averaging group-wise signed weights after normalization. + Useful for visualization. 
+ """ + w = self.weights * self.mask + denom = w.abs().sum(dim=1, keepdim=True).clamp_min(self.norm_eps) # [Dec,1,G] + w_norm = w / denom + return w_norm.mean(dim=-1) # [Dec, Enc] + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + args, + layer_idx: int, + xsa_enabled: bool = False, + xsa_eps: float = 1e-6, + ): + super().__init__() + + dim = args.model_dim + num_heads = args.num_heads + num_kv_heads = args.num_kv_heads + rope_base = args.rope_base + qk_gain_init = args.qk_gain_init + + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + + self.layer_idx = layer_idx + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + + self.rope_dims = args.rope_dims if args.rope_dims > 0 else self.head_dim + self.xsa_enabled = xsa_enabled + self.xsa_eps = xsa_eps + + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + self.learnable_rope = args.learnable_rope + if self.learnable_rope: + init_logits = torch.full((self.head_dim // 2,), -4.0, dtype=torch.float32) + init_logits[:8] = 4.0 + self.rope_mix_logits = nn.Parameter(init_logits) + + # ----------------------------- + # Value path research flags + # ----------------------------- + self.v_skip_enabled = args.v_skip_enabled and (layer_idx >= args.num_layers - args.v_skip_last_n_layers) + self.v_skip_mode = args.v_skip_mode + self.v_skip_group_count = args.v_skip_group_count + + if self.v_skip_enabled: + if self.v_skip_mode == "scalar": + self.v_skip_gate = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + elif self.v_skip_mode == "group": + self.v_skip_gate = nn.Parameter(torch.zeros(self.v_skip_group_count, dtype=torch.float32)) + else: + raise ValueError(f"Unknown V_SKIP_MODE: {self.v_skip_mode}") + + self.cross_layer_v_enabled = ( + args.cross_layer_v_enabled and + (layer_idx >= args.num_layers - args.cross_layer_v_last_n_layers) + ) + self.cross_layer_v_mode = args.cross_layer_v_mode + self.cross_layer_v_group_count = args.cross_layer_v_group_count + + if self.cross_layer_v_enabled: + if args.cross_layer_v_group_count <= 1: + self.cross_layer_v_gate = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + self.cross_layer_v_gate_mode = "scalar" + else: + self.cross_layer_v_gate = nn.Parameter(torch.zeros(args.cross_layer_v_group_count, dtype=torch.float32)) + self.cross_layer_v_gate_mode = "group" + + # ----------------------------- + # Memory path research flags + # ----------------------------- + self.cross_layer_kv_sharing_enabled = ( + args.cross_layer_kv_sharing_enabled and + (layer_idx >= args.num_layers - args.cross_layer_kv_last_n_layers) + ) + self.cross_layer_kv_share_k = args.cross_layer_kv_share_k + self.cross_layer_kv_share_v = args.cross_layer_kv_share_v + self.cross_layer_kv_pairwise = args.cross_layer_kv_pairwise + self.cross_layer_kv_partial_head = args.cross_layer_kv_partial_head + self.cross_layer_kv_partial_head_count = args.cross_layer_kv_partial_head_count + + + def forward( + self, + 
x: Tensor, + shared_k: Tensor | None = None, + shared_v: Tensor | None = None, + prev_v: Tensor | None = None, + ) -> tuple[Tensor, Tensor, Tensor]: + """ + Returns: + out: [B,T,D] + k_eff: [B,Hkv,T,D] + v_eff: [B,Hkv,T,D] + """ + bsz, seqlen, dim = x.shape + + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + + cos, sin = self.rotary(seqlen, x.device, q.dtype) + + if self.learnable_rope: + q_rotated = apply_rotary_emb(q, cos, sin, rope_dims=0) + k_rotated = apply_rotary_emb(k, cos, sin, rope_dims=0) + gamma = torch.sigmoid(self.rope_mix_logits.to(q.dtype)) + gamma = gamma.unsqueeze(-1).expand(-1, 2).reshape(-1) + q = gamma * q_rotated + (1 - gamma) * q + k = gamma * k_rotated + (1 - gamma) * k + else: + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + + # ---------------------------------------------------- + # Memory path: cross-layer KV sharing + # ---------------------------------------------------- + k_eff = k + v_eff = v + + if self.cross_layer_kv_sharing_enabled: + if self.cross_layer_kv_share_k and shared_k is not None: + if self.cross_layer_kv_partial_head: + k_eff = apply_partial_head_sharing(k_eff, shared_k, self.cross_layer_kv_partial_head_count) + else: + k_eff = shared_k + + if self.cross_layer_kv_share_v and shared_v is not None: + if self.cross_layer_kv_partial_head: + v_eff = apply_partial_head_sharing(v_eff, shared_v, self.cross_layer_kv_partial_head_count) + else: + v_eff = shared_v + + # ---------------------------------------------------- + # Value path: cross-layer V sharing + # ---------------------------------------------------- + if self.cross_layer_v_enabled and prev_v is not None: + v_eff = mix_cross_layer_v( + v_cur=v_eff, + v_prev=prev_v, + gate=self.cross_layer_v_gate, + mode=self.cross_layer_v_mode, + group_mode=self.cross_layer_v_gate_mode, + ) + + y = F.scaled_dot_product_attention( + q, + k_eff, + v_eff, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + + if self.xsa_enabled: + y = apply_xsa_gqa_efficient( + y=y, + v=v_eff, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + eps=self.xsa_eps, + ) + + if self.v_skip_enabled: + y = apply_v_skip( + y=y, + v=v_eff, + gate=self.v_skip_gate, + mode=self.v_skip_mode, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + out = self.proj(y) + return out, k_eff, v_eff + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + args, + layer_idx=0, + xsa_enabled=False, + xsa_eps=1e-6 + ): + super().__init__() + dim = args.model_dim + self.layer_idx = layer_idx + + self.attn_norm = RMSNorm(dim, affine=args.affine_norm) + self.mlp_norm = RMSNorm(dim, affine=args.affine_norm) + self.attn = 
CausalSelfAttention( + args, + layer_idx=layer_idx, + xsa_enabled=xsa_enabled, + xsa_eps=xsa_eps, + ) + self.mlp = MLP(dim, args.mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + # === 新增:环境变量控制的 Layer Scale === + self.learnable_ln_scale = args.learnable_ln_scale + + # 计算初始值 (衰减系数) + #init_scale = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + # === 利用实验得出的经验公式:让深层衰减得更快一点,压制后期的方差爆炸 === + # 原版:1.0 / math.sqrt(layer_idx + 1) + # 优化版先验: + init_scale = 1.0 / (math.sqrt(layer_idx + 1) + 0.1 * layer_idx) if args.ln_scale else 1.0 + + self.layer_scale = init_scale # 直接作为浮点数乘上去,无需学习 + + if self.learnable_ln_scale: + # 声明为 1D 参数,以便被 PyTorch 的 parameters() 追踪 + self.layer_scale = nn.Parameter(torch.tensor([init_scale], dtype=torch.float32)) + else: + # 声明为普通 Python 标量,不参与梯度更新 + self.layer_scale = init_scale + + def forward( + self, + x: Tensor, + x0: Tensor, + shared_k: Tensor | None = None, + shared_v: Tensor | None = None, + prev_v: Tensor | None = None, + ) -> tuple[Tensor, Tensor, Tensor]: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + scale = self.layer_scale.to(dtype=x.dtype) if isinstance(self.layer_scale, nn.Parameter) else self.layer_scale + + attn_out, k_eff, v_eff = self.attn( + self.attn_norm(x) * scale, + shared_k=shared_k, + shared_v=shared_v, + prev_v=prev_v, + ) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * scale) + return x, k_eff, v_eff + + # attn_out = self.attn(self.attn_norm(x)) + # x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + # x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + # return x + + +# === NEW: ABLATION ARCHITECTURE IN GPT CLASS === +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +class CausalLocalMixing(nn.Module): + """ + Causal local context mixing via depthwise conv1d. + Per-channel learned softmax weights over a causal window. + Compiles to a single fused kernel instead of window_size tiny kernels. 
+    """
+    def __init__(self, dim: int, window_size: int = 4):
+        super().__init__()
+        self.window_size = window_size
+        self.dim = dim
+        # Learnable logits [window_size, dim]: position 0 = current token
+        w = torch.zeros(window_size, dim, dtype=torch.float32)
+        w[0, :] = 3.0  # Initial bias toward current token (~0.95)
+        self.mix_logits = nn.Parameter(w)
+
+    def forward(self, x: Tensor) -> Tensor:
+        if self.window_size <= 1:
+            return x
+        # x: [B, T, D] -> conv1d expects [B, D, T]
+        # Build depthwise conv weights from softmax logits: [D, 1, W]
+        w_soft = F.softmax(self.mix_logits.to(x.dtype), dim=0)  # [W, D]
+        # Flip so that index 0 (current token) is the rightmost tap
+        kernel = w_soft.flip(0).T.unsqueeze(1)  # [D, 1, W]
+        # Left-pad for causal convolution
+        x_t = x.transpose(1, 2)  # [B, D, T]
+        x_padded = F.pad(x_t, (self.window_size - 1, 0))  # [B, D, T+W-1]
+        out = F.conv1d(x_padded, kernel, groups=self.dim)  # [B, D, T]
+        return out.transpose(1, 2)  # [B, T, D]
+
+class NGramHashEmbedding(nn.Module):
+    def __init__(self, vocab_size: int, dim: int, model_dim: int, max_n: int = 4):
+        """
+        max_n: highest n-gram order to include. For example, max_n=4 means bigram (2),
+        trigram (3), and 4-gram (4) features are all used.
+        """
+        super().__init__()
+        self.max_n = max_n
+        self.vocab_size = vocab_size
+
+        # One independent embedding table per n-gram order.
+        self.embeds = nn.ModuleList([
+            nn.Embedding(vocab_size, dim) for _ in range(2, max_n + 1)
+        ])
+        for emb in self.embeds:
+            # nn.init.zeros_(emb.weight)
+            nn.init.normal_(emb.weight, std=0.01)  # Small non-zero init to break the zero-gradient deadlock.
+
+        self.proj = nn.Linear(dim, model_dim, bias=False) if dim != model_dim else None
+        if self.proj is not None:
+            nn.init.zeros_(self.proj.weight)
+
+        # One independent learnable scale per n-gram order, initialized to 0.05.
+        self.ngram_scales = nn.Parameter(torch.full((max_n - 1,), 0.05, dtype=torch.float32))
+
+    def ngram_hash(self, tokens: Tensor, n: int) -> Tensor:
+        t = tokens.to(torch.int64)  # Use int64 so the multiplications cannot overflow.
+        mod = self.vocab_size - 1
+        out = torch.empty_like(t)
+        out[..., :n-1] = mod
+
+        primes = [36313, 27191, 19393, 13127, 9767]
+        hash_val = t[..., n-1:] * primes[0]
+        for i in range(1, n):
+            hash_val = torch.bitwise_xor(hash_val, t[..., n-1-i : -i] * primes[i])
+
+        out[..., n-1:] = hash_val % mod
+        return out.long()
+
+    def forward(self, token_ids: Tensor) -> Tensor:
+        fused_h = None
+        for idx, n in enumerate(range(2, self.max_n + 1)):
+            hashed_ids = self.ngram_hash(token_ids, n)
+            h_n = self.embeds[idx](hashed_ids)
+            scaled_h = h_n * self.ngram_scales[idx].to(dtype=h_n.dtype)
+
+            if fused_h is None:
+                fused_h = scaled_h
+            else:
+                fused_h = fused_h + scaled_h
+
+        if self.proj is not None:
+            fused_h = self.proj(fused_h)
+        return fused_h
+
+
+def compute_ngram_fade_scale(
+    step: int,
+    total_steps: int,
+    enabled: bool,
+    start_frac: float,
+    end_frac: float,
+    min_scale: float = 0.0,
+) -> float:
+    """
+    Piecewise linear fade-out schedule for N-Gram features.
+ + Before start_frac: scale = 1 + Between start_frac and end_frac: linearly decay to min_scale + After end_frac: scale = min_scale + """ + if not enabled: + return 1.0 + + if total_steps <= 0: + return 1.0 + + p = step / float(total_steps) + start_frac = max(0.0, min(1.0, start_frac)) + end_frac = max(start_frac + 1e-8, min(1.0, end_frac)) + min_scale = max(0.0, min(1.0, min_scale)) + + if p <= start_frac: + return 1.0 + if p >= end_frac: + return min_scale + + alpha = (p - start_frac) / (end_frac - start_frac) + return (1.0 - alpha) * 1.0 + alpha * min_scale + +class GPT(nn.Module): + def __init__(self, args, master_process: bool = True): + super().__init__() + self.fda_mode = args.fda_mode + self.skip_distance = 2 # Configurable skip distance for FDA mode + + model_dim = args.model_dim + num_layers = args.num_layers + self.num_layers = args.num_layers + self.cross_layer_kv_sharing_enabled = args.cross_layer_kv_sharing_enabled + self.cross_layer_kv_last_n_layers = args.cross_layer_kv_last_n_layers + self.cross_layer_kv_pairwise = args.cross_layer_kv_pairwise + + + self.ln_scale = args.ln_scale + self.scaledlm_head = args.scaledlm_head + self.mtphead_mlpmode = args.mtphead_mlpmode + + self.tie_embeddings = args.tie_embeddings + self.tied_embed_init_std = args.tied_embed_init_std + self.logit_softcap = args.logit_softcap + self.tok_emb = nn.Embedding(args.vocab_size, model_dim) + + self.smear_mode = args.smear_mode + self.smear_window = args.smear_window + if self.smear_mode: + self.local_mix = CausalLocalMixing(model_dim, window_size=self.smear_window) + if master_process: + print(f"Architecture: Local Causal Mixing (Window={self.smear_window})") + + self.smear_gate = args.smear_gate + if self.smear_gate: + self.smear_gate_module = SmearGate(model_dim) + if master_process: + print(f"Architecture: SmearGate (1-step causal blend)") + + self.ngram_max_n = args.ngram_max_n + + if args.ngram_vocab_size > 0 and self.ngram_max_n >= 2: + self.ngram = NGramHashEmbedding(args.ngram_vocab_size, args.ngram_dim, model_dim, max_n=self.ngram_max_n) + if master_process: + print(f"Architecture: Discrete N-Gram Hash (Max N={self.ngram_max_n})") + else: + self.ngram = None + + self.register_buffer("ngram_global_scale_buf", torch.tensor(1.0, dtype=torch.float32), persistent=False) + + self.blocks = nn.ModuleList( + [ + Block( + args, + layer_idx=i, + xsa_enabled=(args.xsa_enabled and i >= num_layers - args.xsa_last_n_layers), + xsa_eps=args.xsa_eps + ) + for i in range(num_layers) + ] + ) + + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + + # --- 新增: BiFPN 多路径加法融合开关 --- + self.bifpn_mode = args.bifpn_mode + self.bifpn2_mode = args.bifpn2_mode + if self.bifpn_mode: + w = torch.full((self.num_decoder_layers, self.num_encoder_layers), 0.1, dtype=torch.float32) + # 1. 
基础的对称 U-Net 连接 (对角线) + for i in range(self.num_decoder_layers): + sym_idx = self.num_encoder_layers - 1 - i + if sym_idx >= 0: + w[i, sym_idx] = 1.0 + + # === 新增:根据先验直接注入“特征重构”结构 === + if self.num_decoder_layers >= 2 and self.num_encoder_layers >= 2: + # 倒数第一层 Decoder: 极度排斥最底层的 2 个 Encoder (做高通滤波) + w[-1, 0] = -1.5 + w[-1, 1] = -1.0 + # 倒数第二层 Decoder: 极度渴求最底层的 2 个 Encoder (做基础特征组装) + w[-2, 0] = 0.8 + w[-2, 1] = 0.5 + + self.bifpn_weights = nn.Parameter(w) + if master_process: + print("Architecture: BiFPN Weighted Addition Fusion") + elif self.bifpn2_mode: + self.structured_bifpn = StructuredGroupSignedBiFPN( + num_decoder_layers=self.num_decoder_layers, + num_encoder_layers=self.num_encoder_layers, + model_dim=model_dim, + group_count=args.bifpn_group_count, + band_width=args.bifpn_band_width, + norm_eps=args.bifpn_norm_eps, + init_main=args.bifpn_init_main, + init_neighbor=args.bifpn_init_neighbor, + init_far=args.bifpn_init_far, + ) + if int(os.environ.get("RANK", "0")) == 0: + print( + f"Architecture: StructuredGroupSignedBiFPN " + f"(groups={args.bifpn_group_count}, band={args.bifpn_band_width})" + ) + + elif self.fda_mode: + num_conn = max(0, num_layers - self.skip_distance) + self.skip_weights = nn.Parameter(torch.ones(num_conn, model_dim, dtype=torch.float32)) + else: + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + self.final_norm = RMSNorm(model_dim, affine=args.affine_norm) + self.lm_head = None if args.tie_embeddings else CastedLinear(model_dim, args.vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + + self.mtp_num_heads = args.mtp_num_heads + self.mtp_loss_weight = args.mtp_loss_weight + + # === 终极修复:只有当头数 > 0 时,才去创建和挂载参数! === + if self.mtp_num_heads > 0: + if self.mtphead_mlpmode: + self.mtp_heads = nn.ModuleList([ + nn.Sequential( + nn.Linear(model_dim, model_dim * 2, bias=False), + nn.GELU(), + nn.Linear(model_dim * 2, args.vocab_size, bias=False) + ) for _ in range(self.mtp_num_heads) + ]) + else: + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, args.vocab_size, bias=False) for _ in range(self.mtp_num_heads)] + ) + + for head in self.mtp_heads: + if isinstance(head, nn.Sequential): + head[2]._zero_init = True + else: + head._zero_init = True + else: + # 如果配置为 0,显式设为 None,确保参数量干干净净! 
+ self.mtp_heads = nn.ModuleList([]) + + self.max_logit_pre_cap = 0.0 # Telemetry tracking + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _should_share_from_prev_layer(self, block_idx: int) -> bool: + if not getattr(self, "cross_layer_kv_sharing_enabled", False): + return False + return block_idx >= self.num_layers - self.cross_layer_kv_last_n_layers + + def _should_pair_share(self, block_idx: int) -> bool: + return getattr(self, "cross_layer_kv_pairwise", False) and (block_idx % 2 == 1) + + + def forward(self, input_ids: Tensor, + target_ids: Tensor, + reduction: str = "mean", + ngram_global_scale: float = 1.0) -> Tensor | tuple[Tensor, Tensor]: + + last_v_for_cross_layer_v: Tensor | None = None + last_k_for_kv_sharing: Tensor | None = None + last_v_for_kv_sharing: Tensor | None = None + + x = self.tok_emb(input_ids) + + # === 叠加 N-Gram 特征 === + # if getattr(self, 'ngram', None) is not None: + # x = x + self.ngram(input_ids) + #add N-Gram fade-out + if getattr(self, 'ngram', None) is not None: + scale = self.ngram_global_scale_buf.to(dtype=x.dtype) + x = x + scale * self.ngram(input_ids) + + x = F.rms_norm(x, (x.size(-1),)) + + if self.smear_mode: + x = self.local_mix(x) + + if self.smear_gate: + x = self.smear_gate_module(x) + + x0 = x + if self.bifpn2_mode: + skips: list[Tensor] = [] + + # ----------------------------- + # Encoder + # ----------------------------- + for i in range(self.num_encoder_layers): + shared_k = None + shared_v = None + prev_v = last_v_for_cross_layer_v + + if self._should_share_from_prev_layer(i): + shared_k = last_k_for_kv_sharing + shared_v = last_v_for_kv_sharing + + x, k_eff, v_eff = self.blocks[i]( + x, + x0, + shared_k=shared_k, + shared_v=shared_v, + prev_v=prev_v, + ) + skips.append(x) + + last_v_for_cross_layer_v = v_eff + last_k_for_kv_sharing = k_eff + last_v_for_kv_sharing = v_eff + + # ----------------------------- + # Decoder with structured BiFPN2 fusion + # ----------------------------- + for i in range(self.num_decoder_layers): + fusion_feature = self.structured_bifpn( + skips=skips, + decoder_idx=i, + x_dtype=x.dtype, + ) + x = x + fusion_feature + + block_idx = self.num_encoder_layers + i + shared_k = None + shared_v = None + prev_v = last_v_for_cross_layer_v + + if self._should_share_from_prev_layer(block_idx): + if self._should_pair_share(block_idx): + shared_k = last_k_for_kv_sharing + shared_v = last_v_for_kv_sharing + else: + shared_k = last_k_for_kv_sharing + shared_v = last_v_for_kv_sharing + + x, k_eff, v_eff = self.blocks[block_idx]( + x, + x0, + shared_k=shared_k, + shared_v=shared_v, + prev_v=prev_v, + ) + + last_v_for_cross_layer_v = v_eff + last_k_for_kv_sharing = k_eff + last_v_for_kv_sharing = v_eff + + elif self.bifpn_mode: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + + stacked_skips = torch.stack(skips, dim=0) # [E, B, T, D] + for i in range(self.num_decoder_layers): + w = self.bifpn_weights[i].to(dtype=x.dtype) # [E] + fusion_feature = torch.einsum("e,ebtd->btd", w, stacked_skips) + x = x + fusion_feature + x = self.blocks[self.num_encoder_layers + i](x, x0) + + elif self.fda_mode: + # FDA: Route earlier outputs directly to future inputs + history: list[Tensor] = [] + for i, block in 
enumerate(self.blocks): + lookback_idx = i - self.skip_distance + if lookback_idx >= 0: + w = self.skip_weights[lookback_idx].to(dtype=x.dtype)[None, None, :] + x = x + w * history[lookback_idx] + x = block(x, x0) + history.append(x) + else: + # Baseline: Encoder/Decoder U-Net Skips + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + #x = self.final_norm(x).reshape(-1, x.size(-1)) + #targets = target_ids.reshape(-1) + x = self.final_norm(x) + + # === 新增:在展平(reshape)之前,把保持 3D 形状的特征存下来给 MTP 用 === + x_original = x + + x = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + + if self.tie_embeddings: + if self.scaledlm_head: + # 缩放点积 + logits_proj = F.linear(x, self.tok_emb.weight) / math.sqrt(x.size(-1)) + else: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + if self.scaledlm_head: + logits_proj = self.lm_head(x) / math.sqrt(x.size(-1)) + else: + logits_proj = self.lm_head(x) + + + # 遥测记录 + if not self.training or (hasattr(self, "_log_logits") and self._log_logits): + self.max_logit_pre_cap = logits_proj.detach().abs().max() + + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + if reduction == "none": + # Sliding-window eval: return per-token loss shaped [B, T] + # logits is [B*T, vocab], targets is [B*T] — compute flat then reshape + loss_flat = F.cross_entropy(logits.float(), targets, reduction="none") + loss_tokens = loss_flat.view(input_ids.shape[0], input_ids.shape[1]) + return loss_tokens.mean(), loss_tokens + + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + + # --- MTP Loss 计算 --- + if self.training and getattr(self, 'mtp_num_heads', 0) > 0 and getattr(self, 'mtp_loss_weight', 0.0) > 0.0: + _, seqlen, dim = x_original.shape + mtp_loss_sum = x_original.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + + mtp_hidden = x_original[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + + return main_loss + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so 
grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", grad_accum_steps)) + + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.output_dir: + os.makedirs(args.output_dir, exist_ok=True) + logfile = os.path.join(args.output_dir, f"{time.strftime('%Y%m%d_%H%M%S')}.txt") + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT(args, master_process=master_process).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + + # Single compiled graph. QAT is a branchless buffer multiply inside CastedLinear, + # so the same graph handles both phases — no recompile, no DDP rebuild at qat_start_step. 
+ compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding & N-gram embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars/auxiliary heads use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + + # if base_model.skip_weights.numel() > 0: + # scalar_params.append(base_model.skip_weights) + # 1. Register skip weights from whichever skip/fusion architecture is active + if hasattr(base_model, 'skip_weights') and base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if hasattr(base_model, 'bifpn_weights') and base_model.bifpn_weights.numel() > 0: + scalar_params.append(base_model.bifpn_weights) + if hasattr(base_model, 'structured_bifpn'): + scalar_params.append(base_model.structured_bifpn.weights) + + # 2. Add all MTP head parameters to scalar_params (fixes an earlier bug that registered only head.weight) + if hasattr(base_model, 'mtp_heads') and base_model.mtp_heads is not None: + for p in base_model.mtp_heads.parameters(): + scalar_params.append(p) + # if hasattr(base_model, 'mtp_heads'): + # for head in base_model.mtp_heads: + # scalar_params.append(head.weight) + + # 3. Add the N-Gram projection layer and its scale parameters to scalar_params + if hasattr(base_model, 'ngram') and base_model.ngram is not None: + if base_model.ngram.proj is not None: + scalar_params.append(base_model.ngram.proj.weight) + scalar_params.append(base_model.ngram.ngram_scales) + + +# 4. 
配置 Token Optimizer (将普通的 Embedding 和 N-Gram Embedding 放在一起) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_param_groups = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + + if hasattr(base_model, 'ngram') and base_model.ngram is not None: + for emb in base_model.ngram.embeds: + tok_param_groups.append({"params": [emb.weight], "lr": token_lr, "base_lr": token_lr}) + + optimizer_tok = torch.optim.Adam( + tok_param_groups, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + if master_process: + mode_str = "Forward Dense (k=2)" if getattr(args, 'fda_mode', False) else "Symmetric U-Net" + if getattr(base_model, 'bifpn_mode', False): + mode_str = "BiFPN Weighted Addition Fusion" + print(f"Architecture Skip Mode: {mode_str}") + if getattr(base_model, 'ngram', None) is not None: + print(f"Enhancement: Discrete N-Gram Hash (Max N={base_model.ngram.max_n})") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + use_walltime_stop = args.stop_mode == "walltime" + use_steps_stop = args.stop_mode == "steps" + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if (use_walltime_stop and args.max_wallclock_seconds > 0) else None + # In steps mode the hard budget is max_train_steps (0 means fall back to iterations). 
+ hard_step_limit = args.max_train_steps if (use_steps_stop and args.max_train_steps > 0) else args.iterations + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if use_steps_stop: + warmdown_start = max(hard_step_limit - args.warmdown_iters, 0) + return max((hard_step_limit - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < hard_step_limit else 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) + ema_decay = float(os.environ.get("EMA_DECAY", "0.997")) + + if ema_enabled: + log0(f"EMA Enabled: decay={ema_decay}") + # 在 CPU 或 GPU 上维护一份 FP32 的高精度影子权重 + ema_state = {name: p.detach().float().clone() for name, p in base_model.state_dict().items()} + # 预先提取列表,加速循环内读取 + ema_tensors_list = list(ema_state.values()) + model_tensors_list = list(base_model.state_dict().values()) + + # === Late QAT: compute trigger step === + qat_start_step = int(hard_step_limit * (1.0 - args.late_qat_ratio)) + if args.late_qat_ratio > 0: + log0(f"Scheduled Late QAT to start at step {qat_start_step} (last {args.late_qat_ratio*100:.1f}%)") + + + if master_process: + non_embed_params = sum(p.numel() for name, p in base_model.named_parameters() if 'tok_emb' not in name) + event_fwd_start = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None + event_bwd_end = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None + event_opt_end = 
torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None + + # === NEW: A/B PyTorch Profiler Integration === + prof = None + if args.profile_run: + def trace_handler(p): + if master_process: + out_d = args.profile_output_dir + os.makedirs(out_d, exist_ok=True) + + # Trace + p.export_chrome_trace(os.path.join(out_d, "trace.json")) + + # Summaries + with open(os.path.join(out_d, "summary_self_cuda.txt"), "w") as f: + f.write(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=50)) + with open(os.path.join(out_d, "summary_cuda_total.txt"), "w") as f: + f.write(p.key_averages().table(sort_by="cuda_time_total", row_limit=50)) + with open(os.path.join(out_d, "summary_memory.txt"), "w") as f: + f.write(p.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=50)) + + # Extract Specific Metrics + target_ops = [ + "aten::contiguous", "aten::clone", "aten::copy_", "aten::to", + "aten::repeat_interleave", "aten::permute", "aten::reshape", "aten::view", "Torch-Compiled Region" + ] + op_counts = {op: 0 for op in target_ops} + + max_cuda_alloc = 0 + max_cuda_res = 0 + if torch.cuda.is_available(): + torch.cuda.synchronize() + max_cuda_alloc = torch.cuda.max_memory_allocated() / (1024*1024) + max_cuda_res = torch.cuda.max_memory_reserved() / (1024*1024) + + for evt in p.key_averages(): + if evt.key in op_counts: + op_counts[evt.key] += evt.count + if "CompiledFunction" in evt.key or "Torch-Compiled Region" in evt.key: + if evt.key not in op_counts: + op_counts[evt.key] = 0 + op_counts[evt.key] += evt.count + + time_sorted = sorted(p.key_averages(), key=lambda x: getattr(x, 'self_device_time_total', getattr(x, 'self_cuda_time_total', 0)), reverse=True)[:15] + top_time = [{"name": e.key, "calls": e.count, "ms": getattr(e, 'self_device_time_total', getattr(e, 'self_cuda_time_total', 0))/1000.0} for e in time_sorted] + + mem_sorted = sorted(p.key_averages(), key=lambda x: getattr(x, 'self_device_memory_usage', getattr(x, 'self_cuda_memory_usage', 0)), reverse=True)[:15] + top_mem = [{"name": e.key, "calls": e.count, "mb": getattr(e, 'self_device_memory_usage', getattr(e, 'self_cuda_memory_usage', 0))/(1024*1024)} for e in mem_sorted] + + avg_step_ms = sum(profile_step_times) / len(profile_step_times) if profile_step_times else 0.0 + p50_step_ms = float(np.percentile(profile_step_times, 50)) if profile_step_times else 0.0 + p90_step_ms = float(np.percentile(profile_step_times, 90)) if profile_step_times else 0.0 + # tokens_per_sec = (args.train_batch_tokens * args.train_seq_len) / (avg_step_ms / 1000.0) if avg_step_ms > 0 else 0.0 + #train_batch_tokens is already a token count. Multiplying by train_seq_len is incorrect. + tokens_per_sec = args.train_batch_tokens / (avg_step_ms / 1000.0) + + metrics = { + "avg_step_ms": round(avg_step_ms, 2), + "p50_step_ms": round(p50_step_ms, 2), + "p90_step_ms": round(p90_step_ms, 2), + "tokens_per_sec": round(tokens_per_sec, 2), + "max_cuda_alloc_mb": round(max_cuda_alloc, 2), + "max_cuda_res_mb": round(max_cuda_res, 2), + "op_counts": op_counts, + "top_ops_time": top_time, + "top_ops_memory": top_mem + } + + with open(os.path.join(out_d, "metrics.json"), "w") as f: + json.dump(metrics, f, indent=2) + print(f"✅ Profiling complete. 
Artifacts saved inside: {out_d}") + + print(f"🔍 Starting Scheduled Profiler (Warmup: {args.profile_warmup_steps}, Active: {args.profile_active_steps})...") + prof = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], + schedule=torch.profiler.schedule(wait=1, warmup=args.profile_warmup_steps, active=args.profile_active_steps, repeat=1), + on_trace_ready=trace_handler, + record_shapes=True, + profile_memory=True, + with_flops=True, + with_modules=True, + with_stack=True + ) + prof.start() + + step = 0 + profile_step_times = [] + # Initialize muon_momentum before the loop so telemetry on step 0 doesn't crash + muon_momentum = args.muon_momentum_warmup_start if args.muon_momentum_warmup_steps > 0 else args.muon_momentum + while True: + last_step = step == hard_step_limit or (stop_after_step is not None and step >= stop_after_step) + + # Suppress validation and telemetry during active profiling to keep traces clean. + # Profiler schedule: wait=1, warmup=profile_warmup_steps, active=profile_active_steps. + profiling_active = args.profile_run and step < (1 + args.profile_warmup_steps + args.profile_active_steps) + + should_validate = (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)) and not profiling_active + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + + eval_fn = eval_val_sliding if args.eval_use_sliding_window else eval_val + # Sliding window passes reduction="none" which graph-breaks torch.compile, + # so we pass base_model (uncompiled) for that path. + eval_model = base_model if args.eval_use_sliding_window else model + val_loss, val_bpb = eval_fn( + args, + eval_model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{hard_step_limit} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < hard_step_limit: + log0( + f"stopping_early: {args.stop_mode}_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{hard_step_limit}" + ) + break + + + # === Late QAT: flip buffer to 1.0 — same compiled graph, no recompile === + if step == qat_start_step and args.late_qat_ratio > 0.0: + log0(f"[Step {step}] Activating Late QAT — enabling branchless STE quantization.") + for mod in base_model.modules(): + if isinstance(mod, CastedLinear): + mod.qat_alpha.fill_(1.0) + + step_t0 = time.perf_counter() + elapsed_ms = training_time_ms + 1000.0 * (step_t0 - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + + should_telemetry = (args.telemetry_every > 0) and (step % args.telemetry_every == 0) and not profiling_active + + if should_telemetry and master_process and event_fwd_start is not None: + event_fwd_start.record() + + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + + if should_telemetry and micro_step == grad_accum_steps - 1: + base_model._log_logits = True + + ngram_global_scale = compute_ngram_fade_scale( + step=step, + total_steps=hard_step_limit, + enabled=args.ngram_fade_enable, + 
start_frac=args.ngram_fade_start_frac, + end_frac=args.ngram_fade_end_frac, + min_scale=args.ngram_fade_min_scale, + ) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + #loss = model(x, y) + #loss = model(x, y, ngram_global_scale=ngram_global_scale) + base_model.ngram_global_scale_buf.fill_(float(ngram_global_scale)) + loss = model(x, y) + + if should_telemetry and micro_step == grad_accum_steps - 1: + base_model._log_logits = False + + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + if should_telemetry and master_process and event_bwd_end is not None: + event_bwd_end.record() + + if should_telemetry and master_process: + with torch.no_grad(): + # --- Update how we extract the max_logit_pre_cap --- + logit_val = base_model.max_logit_pre_cap + logit_item = logit_val.item() if isinstance(logit_val, torch.Tensor) else logit_val + + telemetry_data = { + "step": step, + "train_loss": round(train_loss.item(), 4), + "max_logit_pre_cap": round(logit_item, 4), + } + + if base_model.tok_emb.weight.grad is not None: + telemetry_data["grad_norm_embed"] = round(base_model.tok_emb.weight.grad.norm().item(), 4) + + if hasattr(base_model, 'skip_weights') and base_model.skip_weights.numel() > 0: + telemetry_data["skip_weight_mean"] = round(base_model.skip_weights.mean().item(), 4) + telemetry_data["skip_weight_max"] = round(base_model.skip_weights.max().item(), 4) + + if hasattr(base_model, 'bifpn_weights') and base_model.bifpn_weights.numel() > 0: + telemetry_data["bifpn_weight_mean"] = round(base_model.bifpn_weights.mean().item(), 4) + telemetry_data["bifpn_weight_max"] = round(base_model.bifpn_weights.max().item(), 4) + # 偷窥一下模型是不是在用额外的 FPN 路径 (去对角线后的均值) + # 注意:我们这里简单算一下所有权重的均值,如果在变大,说明密集连接生效了! + + # === 新增:将局部混合权重加入标量优化器 === + # if hasattr(base_model, 'smear_mode') and base_model.smear_mode: + # scalar_params.append(base_model.local_mix.mix_logits) + + if hasattr(base_model, 'ngram') and base_model.ngram is not None: + scales = base_model.ngram.ngram_scales.detach().cpu().float().numpy() + if len(scales) > 0: telemetry_data["scale_bigram"] = round(float(scales[0]), 4) + if len(scales) > 1: telemetry_data["scale_trigram"] = round(float(scales[1]), 4) + if len(scales) > 2: telemetry_data["scale_4gram"] = round(float(scales[2]), 4) + + if hasattr(base_model, 'bifpn_mode') and base_model.bifpn_mode: + # 记录整个矩阵的摊平状态,或者直接观察非对角线元素是否变大了 + telemetry_data["bifpn_off_diag_mean"] = round( + base_model.bifpn_weights.detach().clone().fill_diagonal_(0).mean().item(), 4 + ) + + + + # Unfortunately, the way we wrote the forward pass, we only return the combined loss. + # To monitor MTP separately without breaking the compile graph, it's best to observe how the gradient norms of the mtp_heads change. 
+ if hasattr(base_model, 'mtp_heads') and len(base_model.mtp_heads) > 0: + mtp_grad_norm = 0.0 + for p in base_model.mtp_heads.parameters(): + if p.grad is not None: + mtp_grad_norm += p.grad.norm().item() + telemetry_data["grad_norm_mtp"] = round(mtp_grad_norm, 4) + + # === Suggestion 9: XSA config telemetry === + telemetry_data["xsa_enabled"] = int(args.xsa_enabled) + telemetry_data["xsa_last_n_layers"] = args.xsa_last_n_layers + + # === Suggestion 10: Optimizer state logging === + telemetry_data["lr_warmdown_scale"] = round(scale, 6) + telemetry_data["muon_momentum"] = round(muon_momentum, 6) + # Log the actual LR being applied to each optimizer group + for opt_idx, opt in enumerate(optimizers): + for gidx, group in enumerate(opt.param_groups): + key = f"lr_opt{opt_idx}_g{gidx}" + telemetry_data[key] = round(group.get("lr", 0.0), 8) + telemetry_data["qat_active"] = int(step >= qat_start_step and args.late_qat_ratio > 0) + + telemetry_data["ngram_fade_enable"] = int(args.ngram_fade_enable) + telemetry_data["ngram_global_scale"] = round(float(ngram_global_scale), 4) + if hasattr(base_model, 'structured_bifpn'): + eff = base_model.structured_bifpn.export_effective_matrix().detach().cpu() + telemetry_data["bifpn2_eff_mean"] = round(float(eff.mean().item()), 4) + telemetry_data["bifpn2_eff_max"] = round(float(eff.max().item()), 4) + telemetry_data["bifpn2_eff_min"] = round(float(eff.min().item()), 4) + + telemetry_data["v_skip_enabled"] = int(args.v_skip_enabled) + telemetry_data["v_skip_last_n_layers"] = args.v_skip_last_n_layers + telemetry_data["cross_layer_v_enabled"] = int(args.cross_layer_v_enabled) + telemetry_data["cross_layer_v_last_n_layers"] = args.cross_layer_v_last_n_layers + telemetry_data["cross_layer_kv_sharing_enabled"] = int(args.cross_layer_kv_sharing_enabled) + telemetry_data["cross_layer_kv_last_n_layers"] = args.cross_layer_kv_last_n_layers + telemetry_data["cross_layer_kv_pairwise"] = int(args.cross_layer_kv_pairwise) + telemetry_data["cross_layer_kv_partial_head"] = int(args.cross_layer_kv_partial_head) + telemetry_data["cross_layer_kv_partial_head_count"] = args.cross_layer_kv_partial_head_count + + # ========================================== + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + + if prof is not None: + if step > args.profile_warmup_steps and step <= args.profile_warmup_steps + args.profile_active_steps: + profile_step_times.append(1000.0 * (time.perf_counter() - step_t0)) + prof.step() + + if should_telemetry and master_process and event_opt_end is not None: + event_opt_end.record() + event_opt_end.synchronize() + time_fwd_bwd_ms = event_fwd_start.elapsed_time(event_bwd_end) + time_optim_ms = event_bwd_end.elapsed_time(event_opt_end) + step_time_s = (time_fwd_bwd_ms + time_optim_ms) / 1000.0 + + tokens_per_sec = args.train_batch_tokens / max(step_time_s, 0.001) + flops_per_step = 6.0 * non_embed_params * args.train_batch_tokens + tflops_achieved = (flops_per_step / max(step_time_s, 0.001)) / 1e12 + + telemetry_data["time_fwd_bwd_ms"] = round(time_fwd_bwd_ms, 2) + 
telemetry_data["time_optim_ms"] = round(time_optim_ms, 2) + telemetry_data["tokens_per_sec"] = round(tokens_per_sec, 2) + telemetry_data["mfu_tflops"] = round(tflops_achieved, 2) + + if torch.cuda.is_available(): + telemetry_data["memory_allocated_mb"] = round(torch.cuda.max_memory_allocated() / (1024 * 1024), 2) + telemetry_data["memory_reserved_mb"] = round(torch.cuda.max_memory_reserved() / (1024 * 1024), 2) + torch.cuda.reset_peak_memory_stats() + + os.makedirs(os.path.dirname(args.telemetry_file) or ".", exist_ok=True) + with open(args.telemetry_file, "a", encoding="utf-8") as f: + f.write(json.dumps(telemetry_data) + "\n") + zero_grad_all() + + # === EMA Update === + if ema_enabled: + with torch.no_grad(): + update_ema_fused(ema_tensors_list, model_tensors_list, ema_decay) + #Eager mode is slow due to Python loop and multiple dispatch on each parameter, so we use a fused CUDA kernel instead. The commented code below is the naive eager implementation for reference. + # for name, p in base_model.state_dict().items(): + # # ema_weight = decay * ema_weight + (1 - decay) * current_weight + # ema_state[name].mul_(ema_decay).add_(p.detach().float(), alpha=1.0 - ema_decay) + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{hard_step_limit} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Determine whether we've hit the configured stop budget. + if use_steps_stop: + reached_cap = step >= hard_step_limit + else: + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + if ema_enabled: + log0("Applying EMA weights for final evaluation...") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + model_path = os.path.join(args.output_dir, "final_model.pt") if args.output_dir else "final_model.pt" + if master_process: + torch.save(base_model.state_dict(), model_path) + model_bytes = os.path.getsize(model_path) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + quant_model_path = os.path.join(args.output_dir, "final_model.int8.ptz") if args.output_dir else "final_model.int8.ptz" + if master_process: + with open(quant_model_path, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_model_path) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_model_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + if prof is not None: + prof.stop() + + +if __name__ == "__main__": + main() + diff --git a/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/README.md b/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/README.md new file mode 100644 index 0000000000..a15fc04447 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/README.md @@ -0,0 +1,347 @@ +# SP8192 + Value Residual + Byte-Level PPM Mixture + +## Overview + +This submission is the result of an incremental research process rather than a single clean-sheet design. +The codebase was built step by step across many rounds of experiments. Instead of hard-coding one architecture, we intentionally exposed most research ideas as environment-controlled switches so that we could run controlled ablations quickly and compare alternatives under the same training and evaluation framework. + +The final code therefore serves two purposes: + +1. **A trainable compression model** +2. 
**A flexible experiment platform** for architecture, optimization, tokenizer, evaluation, and mixture research + +Our final strongest submission combines: + +- **SentencePiece 8192 tokenizer** +- **9-layer Transformer backbone** +- **BiFPN2 / XSA / N-gram baseline stack** +- **Value Residual in the last 2 layers** +- **Larger-capacity backbone (608d / MLP×3)** +- **Byte-level PPM mixture at evaluation time** + +The best reproduced result in this branch is: + +- **Neural roundtrip exact val_bpb:** `1.15864608` +- **PPM mixture val_bpb:** `0.832925` + +This is a **non-record / unlimited-compute style submission**. +It is intended to document the method and results clearly. It is **not claimed here as a 10-minute 8×H100 record-track run**. + +--- + +## Design Philosophy + +A major goal of this project was to avoid baking one fragile idea directly into the model. +Instead, we built a single training script with many research switches so we could answer questions like: + +- Does a tokenizer change matter more than a block change? +- Does capacity help more than architectural novelty? +- Is value-path routing more useful than additional parallel branches? +- Are small state-space side lanes actually more efficient than attention here? +- Can a lossless/statistical mixture dominate pure-neural improvements? + +Because of that, the code includes toggles for: + +- tokenizer-dependent training and evaluation +- multiple skip/fusion schemes +- XSA, V-skip, cross-layer V and KV sharing +- PLE +- MTP +- N-gram augmentation and fade scheduling +- depth recurrence +- value residual +- parallel residual v1 and parallel residual v2 +- gated linear / conv-gate / tiny SSM side lanes +- LoRA-TTT +- byte-level PPM mixture evaluation + +Most of these ideas were kept behind flags so they could be turned on and off in ablation sweeps without rewriting the training loop. + +--- + +## Code Structure and Research Platform + +The final script is not a minimal competition-only script. It is a research scaffold that gradually accumulated features as experiments progressed. + +### Core training infrastructure +The script includes: + +- distributed training support +- Muon optimizer for matrix parameters +- Adam for scalar/control parameters +- EMA +- late QAT +- tokenizer-aware val_bpb evaluation +- sliding-window evaluation +- telemetry and profiling hooks + +### Architecture flags +The following major research directions are supported through hyperparameter switches: + +- `BIFPN2_MODE` +- `XSA_ENABLED` +- `NGRAM_*` +- `VALUE_RESIDUAL_*` +- `CROSS_LAYER_V_*` +- `CROSS_LAYER_KV_*` +- `PLE_*` +- `PARALLEL_RESIDUAL_*` +- `PARALLEL_V2_*` +- `DEPTH_RECUR_*` +- `TTT_*` +- `LORA_TTT_*` +- `PPM_*` + +This allowed us to run many controlled sweeps without changing the surrounding code. + +--- + +## Experimental Path + +Our final result did **not** come from one idea. +It came from a sequence of findings. + +### 1. Tokenizer scaling mattered a lot +Early SP1024 experiments consistently plateaued around the high `1.27x` range. +Moving to larger tokenizers gave immediate gains: + +- **SP1024:** roughly `~1.27` +- **SP4096:** roughly `~1.24` +- **SP8192:** roughly `~1.22` before stronger backbone tuning + +This showed that tokenization efficiency was one of the highest-leverage early improvements. + +### 2. Capacity still mattered strongly +We then tested model-capacity changes around the strongest SP8192 line. 
+ +Representative pure-neural results: + +- `512d / mlp2` class: around `~1.21` +- `512d / mlp3`: around `~1.19` +- `576d / mlp3`: around `~1.17` +- `608d / mlp3`: **`1.1587258`** +- `576d / mlp4`: **`1.15911626`** + +This strongly suggested that in our regime, increasing effective model capacity was still more valuable than adding many exotic modules. + +### 3. Value Residual became the strongest architectural improvement +Across many rounds of ablations, **Value Residual** was the most consistent structural gain. + +A representative comparison: + +- baseline strongline (`qk400`): about `1.1794` +- `value_resid_last2`: about `1.1712` + +This was a clear and stable gain. +Further ablations showed: + +- `last2` worked better than `last4` or `last6` +- moderate value blending was better than aggressive late-layer replacement +- the gain remained strong on larger-capacity SP8192 backbones + +This became the main architectural direction that survived repeated testing. + +### 4. Many other structural ideas were explored, but did not become the mainline +We tested a wide range of alternatives: + +- PLE +- parallel residual v1 +- parallel residual v2 +- gated merge variants +- cross-layer V residual variants +- cross-layer KV sharing +- tiny SSM / Mamba-like side lanes +- conv-gate side lanes +- LoRA-TTT +- depth recurrence + +Some produced small gains in isolated runs, but none were as consistently useful as: + +- tokenizer scaling +- capacity scaling +- value residual + +In particular, small SSM-style side lanes did **not** outperform a stronger conventional backbone in our experiments. + +### 5. Byte-level PPM mixture changed the regime +The largest jump came when we moved beyond pure-neural evaluation and added a **byte-level PPM mixture**. + +We first validated that the mixture was real rather than noise. +For example, with a strong 608d / mlp3 / value-residual backbone: + +- neural roundtrip exact: about `1.1589` +- PPM mixture (`order=5`, `thr=0.9`, `lo=0.10`, `hi=0.80`): about `0.9385` + +We then investigated lambda sensitivity and found a stronger regime: + +- `hi=0.775`: `0.883714` +- `hi=0.75`: `0.832925` + +The `0.832925` result was repeated and therefore treated as real, not as a one-off anomaly. 
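
To make the lambda gating concrete, the sketch below shows one plausible reading of a single mixture step. It is a minimal illustration, not the exact code in `train_gpt.py`: it assumes the mixing weight is applied to the PPM byte distribution and that `LAMBDA_HI` is used whenever the PPM context's top probability reaches `PPM_CONF_THRESHOLD` (names and values taken from `config.json`); the real script may apply the gate in the opposite direction, and the neural byte projection (`spread_root`) is not modeled here.

```python
# Hedged sketch of a confidence-gated byte mixture step (illustrative only).
import numpy as np

def mix_byte_step(p_nn: np.ndarray, p_ppm: np.ndarray, target_byte: int,
                  conf_threshold: float = 0.9, lam_lo: float = 0.10,
                  lam_hi: float = 0.75) -> float:
    """Bits contributed by one byte under the gated neural/PPM mixture."""
    # Assumed gate: use the larger weight only when the PPM context is confident.
    lam = lam_hi if p_ppm.max() >= conf_threshold else lam_lo
    p_mix = lam * p_ppm + (1.0 - lam) * p_nn  # convex blend, still a distribution
    return float(-np.log2(p_mix[target_byte] + 1e-12))

# Toy usage: summing these bits over all bytes and dividing by the byte count
# gives a mix_bpb-style number like the ones logged in train.log.
p_nn = np.full(256, 1.0 / 256)                             # neural bytes, uniform here for illustration
p_ppm = np.zeros(256); p_ppm[65] = 0.95; p_ppm[66] = 0.05  # a confident order-5 PPM context
print(mix_byte_step(p_nn, p_ppm, target_byte=65))          # ~0.49 bits for this byte
```

The point of the sketch is the shape of the computation (a per-byte convex blend whose weight is switched by a confidence test), not the exact direction of the gate; that direction and the best value of `hi` are exactly what the lambda sweep above was probing.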
+ +This was the point where the project moved from “improving the neural backbone” to “combining a strong neural model with a byte-level statistical corrector.” + +--- + +## Summary of Key Experimental Findings + +### Strong positive findings +- Larger tokenizer vocabularies helped substantially +- Capacity scaling remained very effective +- Value Residual was the strongest consistent architecture change +- Byte-level PPM mixture produced the largest overall gain + +### Weak or inconsistent findings +- PLE sometimes helped slightly, but did not remain on the mainline +- Parallel residual variants were at best marginal on top of value residual +- Cross-layer V and KV sharing were not strong mainline improvements + +### Negative or deprioritized findings +- gated parallel merge +- conv-gate side lanes +- small SSM/Mamba-inspired side lanes +- current LoRA-TTT variants for mainline use +- depth recurrence in its current form + +--- + +## Final Mainline Configuration + +The final strongest reproduced backbone used: + +- **Tokenizer:** SentencePiece 8192 +- **Layers:** 9 +- **Model dimension:** 608 +- **Heads:** 8 +- **KV heads:** 4 +- **MLP multiplier:** 3 +- **QK gain init:** 4.0 +- **BiFPN2:** enabled +- **XSA:** enabled in the last 4 layers +- **N-gram features:** enabled +- **Value Residual:** enabled in the last 2 layers +- **EMA:** enabled +- **Late QAT:** enabled + +The final strongest mixture used: + +- **PPM enabled** +- **Order:** 5 +- **Confidence threshold:** 0.9 +- **Lambda low:** 0.10 +- **Lambda high:** 0.75 +- **Neural byte projection:** `spread_root` + +--- + +## Why the Code Has So Many Switches + +The code may appear larger and more feature-heavy than a minimal submission script. +This is intentional. + +We were not optimizing only for final compactness during research. +We were optimizing for: + +- fast iteration +- controlled ablation +- fair comparisons between ideas +- reuse of one stable training loop +- reproducible experiment sweeps + +This let us test ideas without changing unrelated components. +For example, we could compare: + +- tokenizer changes vs architecture changes +- value residual vs parallel residual +- MLP capacity vs SSM side lanes +- neural-only vs byte-level mixture + +using the same general framework. + +In practice, this made it much easier to discover which ideas were truly load-bearing. + +--- + +## Lessons Learned + +### 1. Simple capacity improvements beat many clever block modifications +A stronger standard backbone often outperformed more exotic second-lane or recurrent additions. + +### 2. Value-path routing mattered more than many alternative residual tricks +Value Residual consistently helped more than most parallel or side-lane variants. + +### 3. Tokenization and byte-level evaluation are first-class concerns in this benchmark +This benchmark is not only about building a stronger LM. +Tokenizer efficiency and byte-level correction matter enormously. + +### 4. System-level methods can dominate pure-neural improvements +The transition from pure-neural `~1.1586` to mixed `0.832925` was much larger than any single block-level improvement. +This suggests that at the current frontier, system design is at least as important as backbone design. + +--- + +## Reproduction Notes + +This submission was built and tested through multiple sweeps using environment-variable controlled configs. 
+ +Representative final backbone: +- `MODEL_DIM=608` +- `NUM_HEADS=8` +- `NUM_KV_HEADS=4` +- `MLP_MULT=3` +- `VALUE_RESIDUAL_ENABLED=1` +- `VALUE_RESIDUAL_LAST_N_LAYERS=2` +- `QK_GAIN_INIT=4.0` + +Representative final mixture: +- `PPM_ENABLED=1` +- `PPM_ORDER=5` +- `PPM_CONF_THRESHOLD=0.9` +- `LAMBDA_LO=0.10` +- `LAMBDA_HI=0.75` + +The strongest reproduced run in this folder is the repeated `hi=0.75` configuration. + +--- + +## Submission Status + +This folder documents a **non-record / unlimited-compute** style submission. + +It is intended to capture: + +- the final strongest reproduced method +- the progression of the experimental mainline +- the code path that enabled the result + +It should be read as a record of the method and its experimental evolution, rather than as a claim of record-track compliance. + +--- + +## Included Files + +This folder contains: + +- `train_gpt.py` — experiment and training script with all major research switches +- `submission.json` — metadata and best-result summary +- `config.json` — final selected configuration +- `seed_runs.csv` — representative run summary +- `train.log` — log from the final best reproduced run +- `requirements.txt` — Python dependencies + +--- + +## Final Result + +### Best reproduced pure-neural score +- **`1.15864608`** roundtrip exact val_bpb + +### Best reproduced mixed score +- **`0.832925`** ppm_mix_bpb + +This final result emerged from a long sequence of ablations, with the most important steps being: + +1. tokenizer scaling +2. capacity scaling +3. value residual +4. byte-level PPM mixture diff --git a/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/config.json b/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/config.json new file mode 100644 index 0000000000..e20cafd773 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/config.json @@ -0,0 +1,111 @@ +{ + "DATA_PATH": "./data/fineweb_multi/datasets/fineweb10B_sp8192", + "TOKENIZER_PATH": "./data/fineweb_multi/tokenizers/fineweb_8192_bpe.model", + "NUM_LAYERS": 9, + "MODEL_DIM": 608, + "NUM_HEADS": 8, + "NUM_KV_HEADS": 4, + "MLP_MULT": 3, + "VOCAB_SIZE": 8192, + "TIE_EMBEDDINGS": 1, + "ROPE_BASE": 10000.0, + "ROPE_DIMS": -1, + "LEARNABLE_ROPE": 0, + "LOGIT_SOFTCAP": 30.0, + "QK_GAIN_INIT": 4.0, + "GRAD_ACCUM_STEPS": 4, + "TRAIN_BATCH_TOKENS": 786432, + "TRAIN_SEQ_LEN": 1024, + "ITERATIONS": 20000, + "WARMUP_STEPS": 20, + "WARMDOWN_ITERS": 1200, + "MAX_WALLCLOCK_SECONDS": 3600.0, + "VAL_BATCH_SIZE": 786432, + "VAL_LOSS_EVERY": 1000, + "TRAIN_LOG_EVERY": 200, + "MATRIX_LR": 0.04, + "SCALAR_LR": 0.04, + "EMBED_LR": 0.6, + "HEAD_LR": 0.008, + "TIED_EMBED_LR": 0.05, + "TIED_EMBED_INIT_STD": 0.005, + "MUON_MOMENTUM": 0.95, + "MUON_BACKEND_STEPS": 5, + "MUON_MOMENTUM_WARMUP_START": 0.85, + "MUON_MOMENTUM_WARMUP_STEPS": 500, + "BETA1": 0.9, + "BETA2": 0.95, + "ADAM_EPS": 1e-08, + "GRAD_CLIP_NORM": 0.0, + "FDA_MODE": 0, + "BIFPN_MODE": 0, + "BIFPN2_MODE": 1, + "BIFPN_GROUP_COUNT": 8, + "BIFPN_BAND_WIDTH": 1, + "BIFPN_NORM_EPS": 0.0001, + "BIFPN_INIT_MAIN": 1.0, + "BIFPN_INIT_NEIGHBOR": 0.15, + "BIFPN_INIT_FAR": 0.0, + "SMEAR_MODE": 0, + "SMEAR_WINDOW": 4, + "SMEAR_GATE": 0, + "LN_SCALE": 1, + "LEARNABLE_LN_SCALE": 0, + "AFFINE_NORM": 0, + "SCALEDLM_HEAD": 1, + "XSA_ENABLED": 1, + "XSA_LAST_N_LAYERS": 4, + "XSA_EPS": 1e-06, + "V_SKIP_ENABLED": 0, + "V_SKIP_LAST_N_LAYERS": 0, + "V_SKIP_MODE": "scalar", + "V_SKIP_GROUP_COUNT": 8, + "CROSS_LAYER_V_ENABLED": 0, + "CROSS_LAYER_V_LAST_N_LAYERS": 4, + 
"CROSS_LAYER_V_MODE": "residual", + "CROSS_LAYER_V_GROUP_COUNT": 4, + "CROSS_LAYER_KV_SHARING_ENABLED": 0, + "CROSS_LAYER_KV_LAST_N_LAYERS": 0, + "CROSS_LAYER_KV_SHARE_K": 1, + "CROSS_LAYER_KV_SHARE_V": 1, + "CROSS_LAYER_KV_PAIRWISE": 0, + "CROSS_LAYER_KV_PARTIAL_HEAD": 0, + "CROSS_LAYER_KV_PARTIAL_HEAD_COUNT": 2, + "VALUE_RESIDUAL_ENABLED": 1, + "VALUE_RESIDUAL_LAST_N_LAYERS": 2, + "VALUE_RESIDUAL_INIT_V0": 0.5, + "VALUE_RESIDUAL_INIT_CUR": 0.5, + "PLE_ENABLED": 0, + "MTP_NUM_HEADS": 0, + "NGRAM_VOCAB_SIZE": 8192, + "NGRAM_DIM": 128, + "NGRAM_MAX_N": 2, + "NGRAM_FADE_ENABLE": 1, + "NGRAM_FADE_START_FRAC": 0.15, + "NGRAM_FADE_END_FRAC": 0.45, + "NGRAM_FADE_MIN_SCALE": 0.0, + "EMA_ENABLED": 1, + "EMA_DECAY": 0.997, + "LATE_QAT_RATIO": 0.15, + "DYNAMIC_CLIP_PERCENTILES": "100.0,99.9999,99.9995,99.995,99.99,99.95,99.9,99.8", + "EVAL_USE_SLIDING_WINDOW": 0, + "EVAL_STRIDE": 1024, + "EVAL_BATCH_SEQS": 16, + "TELEMETRY_EVERY": 50, + "PROFILE_RUN": 0, + "PROFILE_WARMUP_STEPS": 5, + "PROFILE_ACTIVE_STEPS": 10, + "TTT_ENABLED": 0, + "TTT_MODE": "lora", + "LORA_TTT_ENABLED": 0, + "PPM_ENABLED": 1, + "PPM_ORDER": 5, + "PPM_SUBSET_TOKENS": 8000000, + "PPM_CONF_THRESHOLD": 0.9, + "LAMBDA_LO": 0.1, + "LAMBDA_HI": 0.75, + "NN_BYTE_PROJECTION": "spread_root", + "NN_BYTE_UNIFORM_FLOOR": 1e-06, + "STOP_MODE": "steps", + "MAX_TRAIN_STEPS": 3000 +} \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/requirements.txt b/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/requirements.txt new file mode 100644 index 0000000000..911b0e52f0 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/requirements.txt @@ -0,0 +1,10 @@ +numpy +tqdm +torch +huggingface-hub +kernels +setuptools +typing-extensions==4.15.0 +datasets +tiktoken +sentencepiece \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/seed_runs.csv b/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/seed_runs.csv new file mode 100644 index 0000000000..4b90f5817d --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/seed_runs.csv @@ -0,0 +1,3 @@ +experiment,seed,last_val_bpb,roundtrip_val_bpb,roundtrip_exact_val_bpb,submission_bytes,stopped_step,output_dir +verify_hi075_repeat,,1.1577,1.1586,1.15864608,,3000,output/run_sweep_final_verify_hi075_v1/verify_hi075_repeat_20260430_174240 +verify_hi0775_sanity,,1.1578,1.1588,1.15875172,,3000,output/run_sweep_final_verify_hi075_v1/verify_hi0775_sanity_20260430_190707 diff --git a/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/submission.json b/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/submission.json new file mode 100644 index 0000000000..725b58e8b4 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/submission.json @@ -0,0 +1,31 @@ +{ + "title": "SP8192 + Value Residual + PPM Mixture", + "author": "Kaikai Liu", + "github_id": "lkk688", + "track": "non-record", + "compute_track": "unlimited", + "artifact_limit_mb": 16, + "record_eligible": false, + "val_bpb": 0.832925, + "neural_roundtrip_val_bpb": 1.15864608, + "tokenizer": "SentencePiece 8192", + "model_summary": { + "num_layers": 9, + "model_dim": 608, + "num_heads": 8, + "num_kv_heads": 4, + "mlp_mult": 3, + "qk_gain_init": 4.0, + "value_residual_enabled": true, + "value_residual_last_n_layers": 2 + }, + "mixture_summary": { + 
"ppm_enabled": true, + "ppm_order": 5, + "ppm_conf_threshold": 0.9, + "lambda_lo": 0.10, + "lambda_hi": 0.75, + "nn_byte_projection": "spread_root" + }, + "notes": "Non-record unlimited-compute submission. Best reproduced mixture result on SP8192 backbone with value residual and byte-level PPM mixture." +} \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/train.log b/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/train.log new file mode 100644 index 0000000000..386a7ce4c7 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/train.log @@ -0,0 +1,85 @@ +output/run_sweep_final_verify_hi075_v1/verify_hi075_repeat_20260430_174240/20260430_174244.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/fineweb_multi/tokenizers/fineweb_8192_bpe.model +train_loader:dataset:fineweb10B_sp8192 train_shards:128 +val_loader:shards pattern=./data/fineweb_multi/datasets/fineweb10B_sp8192/fineweb_val_*.bin tokens:40541184 +Architecture: Discrete N-Gram Hash (Max N=2) +lora_params:0 +model_params:36072045 +world_size:1 grad_accum_steps:4 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True matrix_lr:0.04 scalar_lr:0.04 +ttt_enabled:False ttt_mode:lora lora_ttt_enabled:False +parallel_v2_enabled:0 mode:dual_add second_lane:mlp active_layers:[] second_lane_params:0 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +EMA Enabled: decay=0.997 +Scheduled Late QAT to start at step 2550 (last 15.0%) +step:0/3000 val_loss:9.0104 val_bpb:3.4882 train_time:4ms step_avg:4.47ms +step:1/3000 train_loss:9.0104 train_time:1907ms step_avg:1906.69ms +step:2/3000 train_loss:8.8420 train_time:3356ms step_avg:1678.15ms +step:3/3000 train_loss:8.3869 train_time:4819ms step_avg:1606.49ms +step:4/3000 train_loss:7.9243 train_time:6309ms step_avg:1577.21ms +step:5/3000 train_loss:7.4828 train_time:7777ms step_avg:1555.48ms +step:6/3000 train_loss:7.1932 train_time:9239ms step_avg:1539.77ms +step:7/3000 train_loss:7.0314 train_time:10699ms step_avg:1528.37ms +step:8/3000 train_loss:6.8973 train_time:12158ms step_avg:1519.69ms +step:9/3000 train_loss:6.6882 train_time:13619ms step_avg:1513.27ms +step:10/3000 train_loss:6.6433 train_time:15079ms step_avg:1507.94ms +step:200/3000 train_loss:3.7348 train_time:293462ms step_avg:1467.31ms +step:400/3000 train_loss:3.4740 train_time:585623ms step_avg:1464.06ms +step:600/3000 train_loss:3.3361 train_time:877863ms step_avg:1463.11ms +step:800/3000 train_loss:3.3155 train_time:1170127ms step_avg:1462.66ms +step:1000/3000 train_loss:3.1813 train_time:1462346ms step_avg:1462.35ms +step:1000/3000 val_loss:3.2105 val_bpb:1.2429 train_time:1462351ms step_avg:1462.35ms +step:1200/3000 train_loss:3.1663 train_time:1754397ms step_avg:1462.00ms +step:1400/3000 train_loss:3.1046 train_time:2046459ms step_avg:1461.76ms +step:1600/3000 train_loss:3.1987 train_time:2338480ms step_avg:1461.55ms +step:1800/3000 train_loss:3.1749 train_time:2630562ms step_avg:1461.42ms +step:2000/3000 train_loss:3.1123 train_time:2922623ms step_avg:1461.31ms +step:2000/3000 val_loss:3.0863 val_bpb:1.1948 train_time:2922627ms step_avg:1461.31ms 
+step:2200/3000 train_loss:3.0783 train_time:3214662ms step_avg:1461.21ms +step:2400/3000 train_loss:3.0946 train_time:3506663ms step_avg:1461.11ms +[Step 2550] Activating Late QAT — enabling branchless STE quantization. +step:2600/3000 train_loss:2.9791 train_time:3798704ms step_avg:1461.04ms +step:2800/3000 train_loss:3.0690 train_time:4090728ms step_avg:1460.97ms +step:3000/3000 train_loss:2.9329 train_time:4382728ms step_avg:1460.91ms +step:3000/3000 val_loss:2.9904 val_bpb:1.1577 train_time:4382736ms step_avg:1460.91ms +peak memory allocated: 49017 MiB reserved: 54226 MiB +Applying EMA weights for final evaluation... +final_int8_zlib_roundtrip val_loss:2.9929 val_bpb:1.1586 +final_int8_zlib_roundtrip_exact val_loss:2.99287347 val_bpb:1.15864608 +Starting PPM byte mixture evaluation... +ppm_mix_progress seq:500/39591 tokens:513024 bytes:1884087 contexts:420910 skipped_ctx:0 mix_bpb:0.780810 +ppm_mix_progress seq:1000/39591 tokens:1025024 bytes:3749771 contexts:627451 skipped_ctx:0 mix_bpb:0.797350 +ppm_mix_progress seq:1500/39591 tokens:1537024 bytes:5612262 contexts:794233 skipped_ctx:0 mix_bpb:0.806758 +ppm_mix_progress seq:2000/39591 tokens:2049024 bytes:7463332 contexts:939360 skipped_ctx:0 mix_bpb:0.810462 +ppm_mix_progress seq:2500/39591 tokens:2561024 bytes:9321552 contexts:1074000 skipped_ctx:0 mix_bpb:0.814928 +ppm_mix_progress seq:3000/39591 tokens:3073024 bytes:11211930 contexts:1184575 skipped_ctx:0 mix_bpb:0.818434 +ppm_mix_progress seq:3500/39591 tokens:3585024 bytes:13070368 contexts:1299746 skipped_ctx:0 mix_bpb:0.821272 +ppm_mix_progress seq:4000/39591 tokens:4097024 bytes:14918758 contexts:1414474 skipped_ctx:0 mix_bpb:0.823674 +ppm_mix_progress seq:4500/39591 tokens:4609024 bytes:16791448 contexts:1509685 skipped_ctx:0 mix_bpb:0.825507 +ppm_mix_progress seq:5000/39591 tokens:5121024 bytes:18644054 contexts:1606254 skipped_ctx:0 mix_bpb:0.827103 +ppm_mix_progress seq:5500/39591 tokens:5633024 bytes:20513194 contexts:1693750 skipped_ctx:0 mix_bpb:0.828218 +ppm_mix_progress seq:6000/39591 tokens:6145024 bytes:22391128 contexts:1772025 skipped_ctx:0 mix_bpb:0.829799 +ppm_mix_progress seq:6500/39591 tokens:6657024 bytes:24298727 contexts:1849517 skipped_ctx:0 mix_bpb:0.830862 +ppm_mix_progress seq:7000/39591 tokens:7169024 bytes:26194772 contexts:1923257 skipped_ctx:0 mix_bpb:0.831779 +ppm_mix_progress seq:7500/39591 tokens:7681024 bytes:28109278 contexts:1994347 skipped_ctx:0 mix_bpb:0.832708 +ppm_mix_bpb:0.832925 diff --git a/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/train_gpt.py b/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/train_gpt.py new file mode 100644 index 0000000000..2cde79e519 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-30_SP8192_PPMMix_608d_ValueResid/train_gpt.py @@ -0,0 +1,3389 @@ +from __future__ import annotations + +""" +mytrain_gpt_v5_lora_ttt.py + +V5 mainline focused on the highest-priority path: + 1) Keep the current strong baseline backbone options. + 2) Add legal score-first LoRA-TTT as the main new feature. + 3) Keep the implementation torch.compile-friendly during normal training. + +This file is intended as a direct evolution target from your v4 script. +It contains the new flags, LoRA modules, TTT adapter wiring, and integration +points you can merge into the existing code. + +Design goals: +- Normal train/inference path remains compile-friendly. +- TTT path is isolated and runs after final int8 roundtrip eval. +- LoRA parameters are only used when TTT is enabled. 
+- Supports warm-start A / reset B, alpha/rank scaling, independent WD. +- Supports score-first legality: score a chunk first, then update on it. +""" + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch.profiler +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +import json + + +# ============================================================ +# HYPERPARAMETERS +# ============================================================ + +class Hyperparameters: + # ----------------------------- + # Data + # ----------------------------- + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # ----------------------------- + # Validation / logging + # ----------------------------- + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # ----------------------------- + # Training schedule + # ----------------------------- + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + stop_mode = os.environ.get("STOP_MODE", "walltime") + max_train_steps = int(os.environ.get("MAX_TRAIN_STEPS", "0")) + + # ----------------------------- + # Sliding eval + # ----------------------------- + eval_use_sliding_window = bool(int(os.environ.get("EVAL_USE_SLIDING_WINDOW", "0"))) + eval_stride = int(os.environ.get("EVAL_STRIDE", "128")) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", "16")) + + # ----------------------------- + # Core model + # ----------------------------- + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + rope_dims = int(os.environ.get("ROPE_DIMS", "-1")) + learnable_rope = bool(int(os.environ.get("LEARNABLE_ROPE", "0"))) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # ----------------------------- + # Optimizer + # ----------------------------- + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = 
float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # ----------------------------- + # Telemetry / profile + # ----------------------------- + output_dir = os.environ.get("OUTPUT_DIR", "") + telemetry_every = int(os.environ.get("TELEMETRY_EVERY", "0")) + telemetry_file = os.environ.get( + "TELEMETRY_FILE", + os.path.join(output_dir, "telemetry.jsonl") if output_dir else "logs/telemetry.jsonl", + ) + profile_run = bool(int(os.environ.get("PROFILE_RUN", "0"))) + profile_warmup_steps = int(os.environ.get("PROFILE_WARMUP_STEPS", "5")) + profile_active_steps = int(os.environ.get("PROFILE_ACTIVE_STEPS", "10")) + profile_step = int(os.environ.get("PROFILE_STEP", "-1")) + profile_output_dir = os.environ.get("PROFILE_OUTPUT_DIR", "output/prof_base") + + # ----------------------------- + # Baseline architecture flags + # ----------------------------- + fda_mode = bool(int(os.environ.get("FDA_MODE", "0"))) + bifpn_mode = bool(int(os.environ.get("BIFPN_MODE", "0"))) + bifpn2_mode = bool(int(os.environ.get("BIFPN2_MODE", "0"))) + bifpn_group_count = int(os.environ.get("BIFPN_GROUP_COUNT", "8")) + bifpn_band_width = int(os.environ.get("BIFPN_BAND_WIDTH", "1")) + bifpn_norm_eps = float(os.environ.get("BIFPN_NORM_EPS", "1e-4")) + bifpn_init_main = float(os.environ.get("BIFPN_INIT_MAIN", "1.0")) + bifpn_init_neighbor = float(os.environ.get("BIFPN_INIT_NEIGHBOR", "0.15")) + bifpn_init_far = float(os.environ.get("BIFPN_INIT_FAR", "0.0")) + + smear_mode = bool(int(os.environ.get("SMEAR_MODE", "0"))) + smear_window = int(os.environ.get("SMEAR_WINDOW", "4")) + smear_gate = bool(int(os.environ.get("SMEAR_GATE", "0"))) + + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + learnable_ln_scale = bool(int(os.environ.get("LEARNABLE_LN_SCALE", "0"))) + affine_norm = bool(int(os.environ.get("AFFINE_NORM", "0"))) + scaledlm_head = bool(int(os.environ.get("SCALEDLM_HEAD", "1"))) + + xsa_enabled = bool(int(os.environ.get("XSA_ENABLED", "0"))) + xsa_last_n_layers = int(os.environ.get("XSA_LAST_N_LAYERS", "0")) + xsa_eps = float(os.environ.get("XSA_EPS", "1e-6")) + + v_skip_enabled = bool(int(os.environ.get("V_SKIP_ENABLED", "0"))) + v_skip_last_n_layers = int(os.environ.get("V_SKIP_LAST_N_LAYERS", "0")) + v_skip_mode = os.environ.get("V_SKIP_MODE", "scalar") + v_skip_group_count = int(os.environ.get("V_SKIP_GROUP_COUNT", "8")) + + cross_layer_v_enabled = bool(int(os.environ.get("CROSS_LAYER_V_ENABLED", "0"))) + cross_layer_v_last_n_layers = int(os.environ.get("CROSS_LAYER_V_LAST_N_LAYERS", "0")) + cross_layer_v_mode = os.environ.get("CROSS_LAYER_V_MODE", "residual") + cross_layer_v_group_count = int(os.environ.get("CROSS_LAYER_V_GROUP_COUNT", "8")) + + cross_layer_kv_sharing_enabled = bool(int(os.environ.get("CROSS_LAYER_KV_SHARING_ENABLED", "0"))) + cross_layer_kv_last_n_layers = int(os.environ.get("CROSS_LAYER_KV_LAST_N_LAYERS", "0")) + cross_layer_kv_share_k = bool(int(os.environ.get("CROSS_LAYER_KV_SHARE_K", "1"))) + cross_layer_kv_share_v = 
bool(int(os.environ.get("CROSS_LAYER_KV_SHARE_V", "1"))) + cross_layer_kv_pairwise = bool(int(os.environ.get("CROSS_LAYER_KV_PAIRWISE", "0"))) + cross_layer_kv_partial_head = bool(int(os.environ.get("CROSS_LAYER_KV_PARTIAL_HEAD", "0"))) + cross_layer_kv_partial_head_count = int(os.environ.get("CROSS_LAYER_KV_PARTIAL_HEAD_COUNT", "2")) + cross_layer_kv_source_mode = os.environ.get("CROSS_LAYER_KV_SOURCE_MODE", "previous") + + # ----------------------------- + # Depth recurrence / value residual + # ----------------------------- + depth_recur_enabled = bool(int(os.environ.get("DEPTH_RECUR_ENABLED", "0"))) + num_stem_blocks = int(os.environ.get("NUM_STEM_BLOCKS", "3")) + num_core_blocks = int(os.environ.get("NUM_CORE_BLOCKS", "3")) + num_core_repeats = int(os.environ.get("NUM_CORE_REPEATS", "3")) + + value_residual_enabled = bool(int(os.environ.get("VALUE_RESIDUAL_ENABLED", "0"))) + value_residual_last_n_layers = int(os.environ.get("VALUE_RESIDUAL_LAST_N_LAYERS", "0")) + value_residual_init_v0 = float(os.environ.get("VALUE_RESIDUAL_INIT_V0", "0.5")) + value_residual_init_cur = float(os.environ.get("VALUE_RESIDUAL_INIT_CUR", "0.5")) + + # ----------------------------- + # PLE + # ----------------------------- + ple_enabled = bool(int(os.environ.get("PLE_ENABLED", "0"))) + ple_temporal_conv = bool(int(os.environ.get("PLE_TEMPORAL_CONV", "0"))) + ple_dim = int(os.environ.get("PLE_DIM", "32")) + ple_mode = os.environ.get("PLE_MODE", "post_attn") + ple_token_scale_init = float(os.environ.get("PLE_TOKEN_SCALE_INIT", "1.0")) + ple_ctx_scale_init = float(os.environ.get("PLE_CTX_SCALE_INIT", "1.0")) + ple_resid_scale_init = float(os.environ.get("PLE_RESID_SCALE_INIT", "0.1")) + + # ----------------------------- + # MTP + # ----------------------------- + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", "0")) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", "0.2")) + mtphead_mlpmode = bool(int(os.environ.get("MTPHEAD_MLPMODE", "0"))) + + # ----------------------------- + # N-gram + # ----------------------------- + ngram_vocab_size = int(os.environ.get("NGRAM_VOCAB_SIZE", "2048")) + ngram_dim = int(os.environ.get("NGRAM_DIM", "128")) + ngram_max_n = int(os.environ.get("NGRAM_MAX_N", "4")) + ngram_fade_enable = bool(int(os.environ.get("NGRAM_FADE_ENABLE", "0"))) + ngram_fade_start_frac = float(os.environ.get("NGRAM_FADE_START_FRAC", "0.15")) + ngram_fade_end_frac = float(os.environ.get("NGRAM_FADE_END_FRAC", "0.45")) + ngram_fade_min_scale = float(os.environ.get("NGRAM_FADE_MIN_SCALE", "0.0")) + + # ----------------------------- + # EMA / QAT + # ----------------------------- + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) + ema_decay = float(os.environ.get("EMA_DECAY", "0.997")) + late_qat_ratio = float(os.environ.get("LATE_QAT_RATIO", "0.15")) + dynamic_clip_percentiles = tuple( + float(x.strip()) + for x in os.environ.get( + "DYNAMIC_CLIP_PERCENTILES", + "100.0,99.9999,99.9995,99.995,99.99,99.95,99.9,99.8", + ).split(",") + if x.strip() + ) + + # ----------------------------- + # V5 MAIN FEATURE: LoRA-TTT + # ----------------------------- + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_mode = os.environ.get("TTT_MODE", "lora") # lora | full + ttt_lr = float(os.environ.get("TTT_LR", "0.002")) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", "1")) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", "49152")) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", "32")) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", "1.0")) + ttt_momentum 
= float(os.environ.get("TTT_MOMENTUM", "0.9")) + ttt_weight_decay = float(os.environ.get("TTT_WEIGHT_DECAY", "1.0")) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", "0")) + + lora_ttt_enabled = bool(int(os.environ.get("LORA_TTT_ENABLED", "0"))) + lora_ttt_rank = int(os.environ.get("LORA_TTT_RANK", "128")) + lora_ttt_alpha = float(os.environ.get("LORA_TTT_ALPHA", "144.0")) + lora_ttt_dropout = float(os.environ.get("LORA_TTT_DROPOUT", "0.0")) + lora_ttt_warm_start_a = bool(int(os.environ.get("LORA_TTT_WARM_START_A", "1"))) + lora_ttt_reset_b_each_chunk = bool(int(os.environ.get("LORA_TTT_RESET_B_EACH_CHUNK", "1"))) + lora_ttt_targets = os.environ.get( + "LORA_TTT_TARGETS", + "attn_q,attn_k,attn_v,attn_proj,mlp_fc,mlp_proj", + ) + + parallel_residual_enabled = bool(int(os.environ.get("PARALLEL_RESIDUAL_ENABLED", "0"))) + parallel_residual_last_n_layers = int(os.environ.get("PARALLEL_RESIDUAL_LAST_N_LAYERS", "0")) + parallel_residual_mode = os.environ.get("PARALLEL_RESIDUAL_MODE", "dual_add") # dual_add | gated_add + parallel_residual_init_attn = float(os.environ.get("PARALLEL_RESIDUAL_INIT_ATTN", "1.0")) + parallel_residual_init_mlp = float(os.environ.get("PARALLEL_RESIDUAL_INIT_MLP", "1.0")) + parallel_residual_gate_init = float(os.environ.get("PARALLEL_RESIDUAL_GATE_INIT", "0.0")) + + # ----------------------------- + # Parallel Residual v2 / hybrid second lane + # ----------------------------- + parallel_v2_enabled = bool(int(os.environ.get("PARALLEL_V2_ENABLED", "0"))) + parallel_v2_last_n_layers = int(os.environ.get("PARALLEL_V2_LAST_N_LAYERS", "2")) + parallel_v2_mode = os.environ.get("PARALLEL_V2_MODE", "dual_add") # dual_add | gated_add | delayed_merge + parallel_v2_second_lane = os.environ.get("PARALLEL_V2_SECOND_LANE", "mlp") # mlp | gated_linear | conv_gate | ssm + parallel_v2_init_attn = float(os.environ.get("PARALLEL_V2_INIT_ATTN", "1.0")) + parallel_v2_init_second = float(os.environ.get("PARALLEL_V2_INIT_SECOND", "1.0")) + parallel_v2_gate_init = float(os.environ.get("PARALLEL_V2_GATE_INIT", "0.0")) + parallel_v2_delayed_merge_steps = int(os.environ.get("PARALLEL_V2_DELAYED_MERGE_STEPS", "1")) + parallel_v2_norm_shared = bool(int(os.environ.get("PARALLEL_V2_NORM_SHARED", "1"))) + parallel_v2_use_post_attn_ple = bool(int(os.environ.get("PARALLEL_V2_USE_POST_ATTN_PLE", "0"))) + parallel_v2_log_norm_ratios = bool(int(os.environ.get("PARALLEL_V2_LOG_NORM_RATIOS", "0"))) + + gated_linear_mult = int(os.environ.get("GATED_LINEAR_MULT", "2")) + gated_linear_zero_init = bool(int(os.environ.get("GATED_LINEAR_ZERO_INIT", "1"))) + conv_gate_kernel_size = int(os.environ.get("CONV_GATE_KERNEL_SIZE", "4")) + ssm_state_dim = int(os.environ.get("SSM_STATE_DIM", "8")) + ssm_expand = int(os.environ.get("SSM_EXPAND", "2")) + ssm_conv_kernel = int(os.environ.get("SSM_CONV_KERNEL", "8")) + ssm_gate = bool(int(os.environ.get("SSM_GATE", "1"))) + + # ----------------------------- + # Lossless / PPM mixture eval + # ----------------------------- + ppm_enabled = bool(int(os.environ.get("PPM_ENABLED", "0"))) + ppm_order = int(os.environ.get("PPM_ORDER", "5")) + ppm_subset_tokens = int(os.environ.get("PPM_SUBSET_TOKENS", "0")) # 0 = full val + ppm_conf_threshold = float(os.environ.get("PPM_CONF_THRESHOLD", "0.9")) + lambda_lo = float(os.environ.get("LAMBDA_LO", "0.05")) + lambda_hi = float(os.environ.get("LAMBDA_HI", "0.9")) + ppm_max_contexts = int(os.environ.get("PPM_MAX_CONTEXTS", "0")) # 0 = unbounded + + # token->byte projection + nn_byte_projection = os.environ.get("NN_BYTE_PROJECTION", 
"spread_root") + nn_byte_uniform_floor = float(os.environ.get("NN_BYTE_UNIFORM_FLOOR", "1e-6")) + + # ----------------------------- + # Eval-only mode + # ----------------------------- + eval_only = bool(int(os.environ.get("EVAL_ONLY", "0"))) + checkpoint = os.environ.get("CHECKPOINT", "") + + +# ============================================================ +# COMPILE-FRIENDLY HELPERS +# ============================================================ + +@torch.compile(dynamic=False, fullgraph=True) +def update_ema_fused(ema_tensors: list[Tensor], model_tensors: list[Tensor], decay: float): + for e, m in zip(ema_tensors, model_tensors): + e.mul_(decay).add_(m.float(), alpha=1.0 - decay) + + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr: curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr: curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ============================================================ +# TOKENIZER-AGNOSTIC EVAL HELPERS +# ============================================================ + +def build_sentencepiece_luts(sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device): + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if 
piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def build_sentencepiece_byte_tables( + sp: spm.SentencePieceProcessor, + vocab_size: int, +): + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + + token_piece_bytes: list[bytes] = [b""] * table_size + has_leading_space = [False] * table_size + is_boundary_token = [True] * table_size + + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + token_piece_bytes[token_id] = b"" + is_boundary_token[token_id] = True + continue + + is_boundary_token[token_id] = False + + if sp.is_byte(token_id): + piece = sp.id_to_piece(token_id) + # piece like <0xAB> + if piece.startswith("<0x") and piece.endswith(">") and len(piece) == 6: + token_piece_bytes[token_id] = bytes([int(piece[3:5], 16)]) + else: + token_piece_bytes[token_id] = b"" + continue + + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space[token_id] = True + piece = piece[1:] + token_piece_bytes[token_id] = piece.encode("utf-8") + + return token_piece_bytes, has_leading_space, is_boundary_token + + +def reconstruct_token_bytes( + prev_token_id: int, + token_id: int, + token_piece_bytes: list[bytes], + has_leading_space: list[bool], + is_boundary_token: list[bool], +) -> bytes: + base = token_piece_bytes[token_id] + if not base: + return b"" + if has_leading_space[token_id] and not is_boundary_token[prev_token_id]: + return b" " + base + return base + + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def tokens_to_bytes_count(xb: Tensor, yb: Tensor, base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor) -> Tensor: + prev_ids = xb.reshape(-1) + tgt_ids = yb.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + return token_bytes.sum() + + +# ============================================================ +# LoRA-TTT MODULES +# ============================================================ + +class BatchedLinearLoRA(nn.Module): + """ + Compile-friendly normal path: + y = base(x) + TTT path: + y = base(x) + scale * B(A(dropout(x))) + + Important details copied from the strong public PR direction: + - scale = alpha / rank + - A can be warm-started across chunks + - B can be reset each chunk + """ + def __init__(self, base: nn.Module, rank: int, alpha: float, dropout: float = 0.0): + super().__init__() + if not isinstance(base, (nn.Linear, CastedLinear)): + raise TypeError(f"Unsupported base module: {type(base)}") + self.base = base + self.in_features = base.in_features + self.out_features = base.out_features + self.rank = rank + self.alpha = alpha + self.scale = alpha / 
max(rank, 1) + self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + + self.lora_A = nn.Parameter(torch.empty(rank, self.in_features, dtype=torch.float32)) + self.lora_B = nn.Parameter(torch.empty(self.out_features, rank, dtype=torch.float32)) + self.lora_enabled = False + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + @torch.no_grad() + def reset_B_only(self): + nn.init.zeros_(self.lora_B) + + def forward(self, x: Tensor) -> Tensor: + y = self.base(x) + if not self.lora_enabled: + return y + x_d = self.dropout(x) + a = F.linear(x_d, self.lora_A.to(dtype=x.dtype), bias=None) + b = F.linear(a, self.lora_B.to(dtype=x.dtype), bias=None) + return y + self.scale * b + + +class LoRATTTManager: + def __init__(self, model: nn.Module, args: Hyperparameters): + self.model = model + self.args = args + self.targets = {s.strip() for s in args.lora_ttt_targets.split(",") if s.strip()} + self.adapters: dict[str, BatchedLinearLoRA] = {} + + def _want_module(self, name: str) -> bool: + mapping = { + "attn_q": ".attn.c_q", + "attn_k": ".attn.c_k", + "attn_v": ".attn.c_v", + "attn_proj": ".attn.proj", + "mlp_fc": ".mlp.fc", + "mlp_proj": ".mlp.proj", + } + for k, suffix in mapping.items(): + if k in self.targets and name.endswith(suffix): + return True + return False + + def inject(self): + replacements = [] + for name, module in self.model.named_modules(): + if self._want_module(name) and isinstance(module, (nn.Linear, CastedLinear)): + replacements.append((name, module)) + + for full_name, module in replacements: + parent_name, child_name = full_name.rsplit('.', 1) + parent = self.model.get_submodule(parent_name) + wrapped = BatchedLinearLoRA( + base=module, + rank=self.args.lora_ttt_rank, + alpha=self.args.lora_ttt_alpha, + dropout=self.args.lora_ttt_dropout, + ) + setattr(parent, child_name, wrapped) + self.adapters[full_name] = wrapped + + def set_enabled(self, enabled: bool): + for mod in self.adapters.values(): + mod.lora_enabled = enabled + + def lora_parameters(self): + for mod in self.adapters.values(): + yield mod.lora_A + yield mod.lora_B + + @torch.no_grad() + def reset_chunk_state(self): + for mod in self.adapters.values(): + if self.args.lora_ttt_reset_b_each_chunk: + mod.reset_B_only() + if not self.args.lora_ttt_warm_start_a: + nn.init.kaiming_uniform_(mod.lora_A, a=math.sqrt(5)) + + +# ============================================================ +# MODEL BUILDING BLOCKS +# ============================================================ + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,vr_lambda,parallel_attn_scale,parallel_mlp_scale,parallel_gate,parallel_v2_attn_scale,parallel_v2_second_scale,parallel_v2_gate,ssm_A_log,ssm_B,ssm_C", + ).split(",") if p +) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int | None = None, eps: float | None = None, affine: bool = False): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) if (affine and dim is not None) else None + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) if self.weight is not None else None + return F.rms_norm(x, (x.size(-1),), weight=w, eps=self.eps) + + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, bias: bool = False): + super().__init__(in_features, 
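+        # Note (added): CastedLinear fake-quantizes its weights per output row to int8
+        # (scale = row |w|_max / 127, values rounded and clamped to [-127, 127]) and blends the
+        # quantized weights in with the qat_alpha buffer; the blend term is detached, so gradients
+        # always follow the unquantized weights (the "branchless STE" path). qat_alpha presumably
+        # stays at 0.0 during normal training and is raised when Late QAT activates.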
out_features, bias=bias) + self.register_buffer("qat_alpha", torch.tensor(0.0, dtype=torch.float32), persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + w = self.weight + w_max = w.detach().abs().amax(dim=1, keepdim=True) + scale = (w_max / 127.0).clamp_min(1e-7) + w_quant = torch.clamp(torch.round(w / scale), -127, 127) * scale + w = w + (self.qat_alpha * (w_quant - w)).detach() + return F.linear(x, w.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype): + if self._cos_cached is None or self._sin_cached is None or self._seq_len_cached != seq_len or self._cos_cached.device != device: + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + half = rope_dims // 2 + x1 = x[..., :half] + x2 = x[..., half:rope_dims] + x_pass = x[..., rope_dims:] + cos_part = cos[..., :half] + sin_part = sin[..., :half] + return torch.cat(( + x1 * cos_part + x2 * sin_part, + x1 * (-sin_part) + x2 * cos_part, + x_pass, + ), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +def _expand_group_gates(g: Tensor, total_dim: int) -> Tensor: + if total_dim % g.numel() != 0: + raise ValueError(f"total_dim ({total_dim}) must be divisible by num_groups ({g.numel()})") + group_dim = total_dim // g.numel() + return g.repeat_interleave(group_dim) + + +def apply_v_skip(y: Tensor, v: Tensor, gate: Tensor, mode: str = "scalar", num_heads: int | None = None, num_kv_heads: int | None = None) -> Tensor: + b, h, t, d = y.shape + hkv = v.shape[1] + if h == hkv: + v_exp = v + else: + group_size = num_heads // num_kv_heads + v_exp = v.unsqueeze(2).expand(b, hkv, group_size, t, d).reshape(b, h, t, d) + if mode == "scalar": + g = torch.sigmoid(gate.to(dtype=y.dtype)).reshape(1, 1, 1, 1) + return y + g * v_exp + if mode == "group": + g = torch.sigmoid(gate.to(dtype=y.dtype)) + g = _expand_group_gates(g, d).view(1, 1, 1, d) + return y + g * v_exp + raise ValueError(f"Unknown V_SKIP_MODE: {mode}") + + +def mix_cross_layer_v(v_cur: Tensor, v_prev: Tensor, gate: Tensor, mode: str = "residual", group_mode: str = "scalar") -> Tensor: + d = v_cur.shape[-1] + if group_mode == "scalar": + g = torch.sigmoid(gate.to(dtype=v_cur.dtype)).reshape(1, 1, 1, 1) + elif group_mode == "group": + g = torch.sigmoid(gate.to(dtype=v_cur.dtype)) + g = _expand_group_gates(g, d).view(1, 1, 1, 
d) + else: + raise ValueError(f"Unknown cross-layer V group mode: {group_mode}") + if mode == "residual": + return v_cur + g * v_prev + if mode == "blend": + return (1.0 - g) * v_cur + g * v_prev + raise ValueError(f"Unknown CROSS_LAYER_V_MODE: {mode}") + + +def apply_partial_head_sharing(cur: Tensor, shared: Tensor, share_head_count: int) -> Tensor: + h = cur.shape[1] + n = min(max(share_head_count, 0), h) + if n == 0: + return cur + out = cur.clone() + out[:, :n] = shared[:, :n] + return out + + +def apply_xsa_gqa_efficient(y: Tensor, v: Tensor, num_heads: int, num_kv_heads: int, eps: float = 1e-6) -> Tensor: + if num_heads == num_kv_heads: + vn = v / (v.norm(dim=-1, keepdim=True) + eps) + proj = (y * vn).sum(dim=-1, keepdim=True) + return y - proj * vn + group_size = num_heads // num_kv_heads + b, h, t, d = y.shape + yg = y.view(b, num_kv_heads, group_size, t, d) + vn = v / (v.norm(dim=-1, keepdim=True) + eps) + vn = vn.unsqueeze(2) + proj = (yg * vn).sum(dim=-1, keepdim=True) + yg = yg - proj * vn + return yg.view(b, h, t, d) + + +class StructuredGroupSignedBiFPN(nn.Module): + def __init__(self, num_decoder_layers, num_encoder_layers, model_dim, group_count=8, band_width=1, norm_eps=1e-4, init_main=1.0, init_neighbor=0.15, init_far=0.0): + super().__init__() + if model_dim % group_count != 0: + raise ValueError(f"model_dim ({model_dim}) must be divisible by group_count ({group_count})") + self.num_decoder_layers = num_decoder_layers + self.num_encoder_layers = num_encoder_layers + self.model_dim = model_dim + self.group_count = group_count + self.group_dim = model_dim // group_count + self.band_width = band_width + self.norm_eps = norm_eps + w = torch.full((num_decoder_layers, num_encoder_layers, group_count), init_far, dtype=torch.float32) + for d in range(num_decoder_layers): + sym = num_encoder_layers - 1 - d + for e in range(num_encoder_layers): + dist_val = abs(e - sym) + if dist_val == 0: + w[d, e, :] = init_main + elif dist_val <= band_width: + w[d, e, :] = init_neighbor + mask = torch.zeros((num_decoder_layers, num_encoder_layers, 1), dtype=torch.float32) + for d in range(num_decoder_layers): + sym = num_encoder_layers - 1 - d + for e in range(num_encoder_layers): + if abs(e - sym) <= band_width: + mask[d, e, 0] = 1.0 + self.weights = nn.Parameter(w) + self.register_buffer("mask", mask, persistent=True) + + def forward(self, skips: list[Tensor], decoder_idx: int, x_dtype: torch.dtype) -> Tensor: + stacked = torch.stack(skips, dim=0) + enc, b, t, d = stacked.shape + stacked_g = stacked.view(enc, b, t, self.group_count, self.group_dim) + w = self.weights[decoder_idx] * self.mask[decoder_idx] + w = w.to(dtype=x_dtype) + denom = w.abs().sum(dim=0, keepdim=True).clamp_min(self.norm_eps) + w_norm = w / denom + fused = torch.einsum("eg,ebtgd->btgd", w_norm, stacked_g) + return fused.reshape(b, t, d) + + +class PLEModule(nn.Module): + def __init__(self, args: Hyperparameters): + super().__init__() + self.enabled = args.ple_enabled + self.temporal_conv = args.ple_temporal_conv + self.dim = args.ple_dim + self.mode = args.ple_mode + self.num_layers = args.num_layers + self.model_dim = args.model_dim + if not self.enabled: + return + self.token_embed = nn.Embedding(args.vocab_size, self.num_layers * self.dim) + self.ctx_proj = nn.Linear(self.model_dim, self.num_layers * self.dim, bias=False) + self.out_proj = nn.Linear(self.dim, self.model_dim, bias=False) + self.token_scale = nn.Parameter(torch.full((1,), args.ple_token_scale_init, dtype=torch.float32)) + self.ctx_scale = 
nn.Parameter(torch.full((1,), args.ple_ctx_scale_init, dtype=torch.float32)) + self.resid_scale = nn.Parameter(torch.full((1,), args.ple_resid_scale_init, dtype=torch.float32)) + if self.temporal_conv: + self.temporal = nn.Conv1d(self.dim, self.dim, kernel_size=3, padding=1, groups=self.dim, bias=False) + else: + self.temporal = None + + def build_all(self, token_ids: Tensor, x_embed: Tensor) -> Tensor | None: + if not self.enabled: + return None + tok = self.token_embed(token_ids).view(token_ids.shape[0], token_ids.shape[1], self.num_layers, self.dim) + ctx = self.ctx_proj(x_embed).view(token_ids.shape[0], token_ids.shape[1], self.num_layers, self.dim) + out = self.token_scale.to(tok.dtype) * tok + self.ctx_scale.to(tok.dtype) * ctx + if self.temporal is not None: + b, t, l, d = out.shape + tmp = out.permute(0, 2, 3, 1).reshape(b * l, d, t) + tmp = self.temporal(tmp) + out = tmp.reshape(b, l, d, t).permute(0, 3, 1, 2) + return out + + def apply(self, x: Tensor, ple_all: Tensor | None, layer_idx: int) -> Tensor: + if ple_all is None: + return x + p = ple_all[:, :, layer_idx, :] + p = self.out_proj(p) + return x + self.resid_scale.to(x.dtype) * p.to(x.dtype) + + +class CausalSelfAttention(nn.Module): + def __init__(self, args: Hyperparameters, layer_idx: int, xsa_enabled: bool = False, xsa_eps: float = 1e-6): + super().__init__() + dim = args.model_dim + num_heads = args.num_heads + num_kv_heads = args.num_kv_heads + rope_base = args.rope_base + qk_gain_init = args.qk_gain_init + + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + + self.layer_idx = layer_idx + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + + self.rope_dims = args.rope_dims if args.rope_dims > 0 else self.head_dim + self.xsa_enabled = xsa_enabled + self.xsa_eps = xsa_eps + self.learnable_rope = args.learnable_rope + + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + if self.learnable_rope: + init_logits = torch.full((self.head_dim // 2,), -4.0, dtype=torch.float32) + init_logits[:8] = 4.0 + self.rope_mix_logits = nn.Parameter(init_logits) + + self.v_skip_enabled = args.v_skip_enabled and (layer_idx >= args.num_layers - args.v_skip_last_n_layers) + self.v_skip_mode = args.v_skip_mode + self.v_skip_group_count = args.v_skip_group_count + if self.v_skip_enabled: + if self.v_skip_mode == "scalar": + self.v_skip_gate = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + elif self.v_skip_mode == "group": + self.v_skip_gate = nn.Parameter(torch.zeros(self.v_skip_group_count, dtype=torch.float32)) + else: + raise ValueError(f"Unknown V_SKIP_MODE: {self.v_skip_mode}") + + self.cross_layer_v_enabled = args.cross_layer_v_enabled and (layer_idx >= args.num_layers - args.cross_layer_v_last_n_layers) + self.cross_layer_v_mode = args.cross_layer_v_mode + self.cross_layer_v_group_count = args.cross_layer_v_group_count + if self.cross_layer_v_enabled: + if args.cross_layer_v_group_count <= 1: + 
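+            # Cross-layer V gating (added comment): a group count <= 1 falls back to a single scalar
+            # gate; otherwise one logit per group is learned and expanded across head_dim by
+            # _expand_group_gates inside mix_cross_layer_v. Logits start at zero, so the sigmoid
+            # gate opens at 0.5 at init.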
self.cross_layer_v_gate = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + self.cross_layer_v_gate_mode = "scalar" + else: + self.cross_layer_v_gate = nn.Parameter(torch.zeros(args.cross_layer_v_group_count, dtype=torch.float32)) + self.cross_layer_v_gate_mode = "group" + + self.cross_layer_kv_sharing_enabled = args.cross_layer_kv_sharing_enabled and (layer_idx >= args.num_layers - args.cross_layer_kv_last_n_layers) + self.cross_layer_kv_share_k = args.cross_layer_kv_share_k + self.cross_layer_kv_share_v = args.cross_layer_kv_share_v + self.cross_layer_kv_pairwise = args.cross_layer_kv_pairwise + self.cross_layer_kv_partial_head = args.cross_layer_kv_partial_head + self.cross_layer_kv_partial_head_count = args.cross_layer_kv_partial_head_count + + self.value_residual_enabled = args.value_residual_enabled and (layer_idx >= args.num_layers - args.value_residual_last_n_layers) + if self.value_residual_enabled: + self.vr_lambda = nn.Parameter(torch.tensor([args.value_residual_init_v0, args.value_residual_init_cur], dtype=torch.float32)) + + def forward(self, x: Tensor, shared_k: Tensor | None = None, shared_v: Tensor | None = None, prev_v: Tensor | None = None, v0: Tensor | None = None): + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + raw_v = v + + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.learnable_rope: + q_rot = apply_rotary_emb(q, cos, sin, rope_dims=0) + k_rot = apply_rotary_emb(k, cos, sin, rope_dims=0) + gamma = torch.sigmoid(self.rope_mix_logits.to(q.dtype)) + gamma = gamma.unsqueeze(-1).expand(-1, 2).reshape(-1) + q = gamma * q_rot + (1 - gamma) * q + k = gamma * k_rot + (1 - gamma) * k + else: + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + + k_eff = k + v_eff = v + + if self.cross_layer_kv_sharing_enabled: + if self.cross_layer_kv_share_k and shared_k is not None: + if self.cross_layer_kv_partial_head: + k_eff = apply_partial_head_sharing(k_eff, shared_k, self.cross_layer_kv_partial_head_count) + else: + k_eff = shared_k + if self.cross_layer_kv_share_v and shared_v is not None: + if self.cross_layer_kv_partial_head: + v_eff = apply_partial_head_sharing(v_eff, shared_v, self.cross_layer_kv_partial_head_count) + else: + v_eff = shared_v + + if self.cross_layer_v_enabled and prev_v is not None: + v_eff = mix_cross_layer_v(v_eff, prev_v, self.cross_layer_v_gate, mode=self.cross_layer_v_mode, group_mode=self.cross_layer_v_gate_mode) + + if self.value_residual_enabled and v0 is not None: + lam = self.vr_lambda.to(dtype=v_eff.dtype) + v_eff = lam[0] * v0 + lam[1] * v_eff + + y = F.scaled_dot_product_attention( + q, k_eff, v_eff, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + + if self.xsa_enabled: + y = apply_xsa_gqa_efficient(y=y, v=v_eff, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, eps=self.xsa_eps) + + if self.v_skip_enabled: + y = apply_v_skip(y=y, v=v_eff, gate=self.v_skip_gate, mode=self.v_skip_mode, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads) + + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + out = self.proj(y) + return out, k_eff, v_eff, 
raw_v + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class GatedLinearLane(nn.Module): + def __init__(self, dim: int, mult: int = 2, zero_init: bool = True): + super().__init__() + self.hidden_dim = max(1, int(mult)) * dim + self.up_proj = CastedLinear(dim, self.hidden_dim, bias=False) + self.gate_proj = CastedLinear(dim, self.hidden_dim, bias=False) + self.down_proj = CastedLinear(self.hidden_dim, dim, bias=False) + if zero_init: + self.down_proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + return self.down_proj(self.up_proj(x) * F.silu(self.gate_proj(x))) + + +class ConvGateLane(nn.Module): + def __init__(self, dim: int, kernel_size: int = 4): + super().__init__() + self.kernel_size = max(1, int(kernel_size)) + self.dwconv = nn.Conv1d(dim, dim, kernel_size=self.kernel_size, groups=dim, bias=False) + self.gate_proj = CastedLinear(dim, dim, bias=False) + self.pointwise = CastedLinear(dim, dim, bias=False) + self.pointwise._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + y = x.transpose(1, 2) + if self.kernel_size > 1: + y = F.pad(y, (self.kernel_size - 1, 0)) + y = self.dwconv(y).transpose(1, 2) + return self.pointwise(y * torch.sigmoid(self.gate_proj(x))) + + +class TinySSMLane(nn.Module): + def __init__(self, dim: int, state_dim: int = 8, expand: int = 2, conv_kernel: int = 8, gate: bool = True): + super().__init__() + self.inner_dim = max(1, int(expand)) * dim + self.state_dim = max(1, int(state_dim)) + self.conv_kernel = max(1, int(conv_kernel)) + self.use_gate = bool(gate) + + self.in_proj = CastedLinear(dim, self.inner_dim, bias=False) + self.out_proj = CastedLinear(self.inner_dim, dim, bias=False) + self.out_proj._zero_init = True + self.gate_proj = CastedLinear(dim, self.inner_dim, bias=False) if self.use_gate else None + + self.ssm_A_log = nn.Parameter(torch.zeros(self.inner_dim, self.state_dim, dtype=torch.float32)) + self.ssm_B = nn.Parameter(torch.randn(self.inner_dim, self.state_dim, dtype=torch.float32) * 0.02) + self.ssm_C = nn.Parameter(torch.randn(self.inner_dim, self.state_dim, dtype=torch.float32) * 0.02) + self.register_buffer("ssm_t", torch.arange(self.conv_kernel, dtype=torch.float32), persistent=False) + + def _causal_kernel(self, dtype: torch.dtype, device: torch.device) -> Tensor: + t = self.ssm_t.to(device=device, dtype=torch.float32) + decay = torch.exp(-F.softplus(self.ssm_A_log.float()).unsqueeze(-1) * t.view(1, 1, -1)) + weights = (self.ssm_B.float() * self.ssm_C.float()).unsqueeze(-1) * decay + kernel = weights.sum(dim=1).to(dtype=dtype).flip(-1) + return kernel.unsqueeze(1) + + def forward(self, x: Tensor) -> Tensor: + u = self.in_proj(x) + y = u.transpose(1, 2) + if self.conv_kernel > 1: + y = F.pad(y, (self.conv_kernel - 1, 0)) + y = F.conv1d(y, self._causal_kernel(dtype=u.dtype, device=u.device), groups=self.inner_dim).transpose(1, 2) + if self.gate_proj is not None: + y = y * torch.sigmoid(self.gate_proj(x)) + return self.out_proj(F.silu(y)) + + +def count_trainable_params(module: nn.Module | None) -> int: + if module is None: + return 0 + return sum(p.numel() for p in module.parameters() if p.requires_grad) + + +class Block(nn.Module): + def __init__(self, args: Hyperparameters, layer_idx=0, xsa_enabled=False, 
xsa_eps=1e-6): + super().__init__() + dim = args.model_dim + self.layer_idx = layer_idx + self.ple_mode = args.ple_mode + + self.attn_norm = RMSNorm(dim, affine=args.affine_norm) + self.mlp_norm = RMSNorm(dim, affine=args.affine_norm) + self.attn = CausalSelfAttention(args, layer_idx=layer_idx, xsa_enabled=xsa_enabled, xsa_eps=xsa_eps) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + self.parallel_residual_enabled = ( + args.parallel_residual_enabled and + (layer_idx >= args.num_layers - args.parallel_residual_last_n_layers) + ) + self.parallel_residual_mode = args.parallel_residual_mode + + self.parallel_v2_enabled = ( + args.parallel_v2_enabled and + (layer_idx >= args.num_layers - args.parallel_v2_last_n_layers) + ) + self.parallel_v2_mode = args.parallel_v2_mode + self.parallel_v2_second_lane_name = args.parallel_v2_second_lane + self.parallel_v2_norm_shared = args.parallel_v2_norm_shared + self.parallel_v2_use_post_attn_ple = args.parallel_v2_use_post_attn_ple + self.parallel_v2_delayed_merge = self.parallel_v2_enabled and self.parallel_v2_mode == "delayed_merge" + self.parallel_v2_capture_norm_ratios = False + + if self.parallel_residual_enabled and self.parallel_v2_enabled: + raise ValueError("PARALLEL_RESIDUAL_ENABLED and PARALLEL_V2_ENABLED cannot both apply to the same layer") + + needs_mlp = (not self.parallel_v2_enabled) or self.parallel_residual_enabled or self.parallel_v2_second_lane_name == "mlp" + self.mlp = MLP(dim, args.mlp_mult) if needs_mlp else None + + self.second_lane: nn.Module | None = None + if self.parallel_v2_enabled and self.parallel_v2_second_lane_name != "mlp": + if self.parallel_v2_second_lane_name == "gated_linear": + self.second_lane = GatedLinearLane(dim, mult=args.gated_linear_mult, zero_init=args.gated_linear_zero_init) + elif self.parallel_v2_second_lane_name == "conv_gate": + self.second_lane = ConvGateLane(dim, kernel_size=args.conv_gate_kernel_size) + elif self.parallel_v2_second_lane_name == "ssm": + self.second_lane = TinySSMLane( + dim, + state_dim=args.ssm_state_dim, + expand=args.ssm_expand, + conv_kernel=args.ssm_conv_kernel, + gate=args.ssm_gate, + ) + else: + raise ValueError(f"Unknown PARALLEL_V2_SECOND_LANE: {self.parallel_v2_second_lane_name}") + + self.learnable_ln_scale = args.learnable_ln_scale + init_scale = 1.0 / (math.sqrt(layer_idx + 1) + 0.1 * layer_idx) if args.ln_scale else 1.0 + if self.learnable_ln_scale: + self.layer_scale = nn.Parameter(torch.tensor([init_scale], dtype=torch.float32)) + else: + self.layer_scale = init_scale + + if self.parallel_residual_enabled: + self.parallel_attn_scale = nn.Parameter( + torch.full((dim,), args.parallel_residual_init_attn, dtype=torch.float32) + ) + self.parallel_mlp_scale = nn.Parameter( + torch.full((dim,), args.parallel_residual_init_mlp, dtype=torch.float32) + ) + if self.parallel_residual_mode == "gated_add": + self.parallel_gate = nn.Parameter( + torch.full((dim,), args.parallel_residual_gate_init, dtype=torch.float32) + ) + + if self.parallel_v2_enabled: + self.attn_scale.requires_grad_(False) + self.mlp_scale.requires_grad_(False) + if self.parallel_v2_norm_shared and self.mlp_norm.weight is not None: + self.mlp_norm.weight.requires_grad_(False) + self.parallel_v2_attn_scale = nn.Parameter( + torch.full((dim,), args.parallel_v2_init_attn, dtype=torch.float32) + ) + self.parallel_v2_second_scale = 
nn.Parameter( + torch.full((dim,), args.parallel_v2_init_second, dtype=torch.float32) + ) + if self.parallel_v2_mode == "gated_add": + self.parallel_v2_gate = nn.Parameter( + torch.full((dim,), args.parallel_v2_gate_init, dtype=torch.float32) + ) + self.register_buffer("parallel_v2_attn_norm_ratio", torch.tensor(float("nan"), dtype=torch.float32), persistent=False) + self.register_buffer("parallel_v2_second_norm_ratio", torch.tensor(float("nan"), dtype=torch.float32), persistent=False) + + def _apply_ple(self, x: Tensor, ple_all: Tensor | None, ple_apply, mode: str) -> Tensor: + if ple_apply is not None and self.ple_mode == mode: + return ple_apply(x, ple_all, self.layer_idx) + return x + + def _parallel_v2_second_lane(self, x: Tensor) -> Tensor: + if self.parallel_v2_second_lane_name == "mlp": + if self.mlp is None: + raise RuntimeError("MLP lane was not constructed") + return self.mlp(x) + if self.second_lane is None: + raise RuntimeError(f"Second lane was not constructed: {self.parallel_v2_second_lane_name}") + return self.second_lane(x) + + def _parallel_v2_merge(self, x_base: Tensor, attn_out: Tensor, second_out: Tensor) -> Tensor: + attn_delta, second_delta = self._parallel_v2_scaled_lanes(attn_out, second_out, x_base.dtype) + if self.parallel_v2_mode == "dual_add" or self.parallel_v2_mode == "delayed_merge": + return x_base + attn_delta + second_delta + if self.parallel_v2_mode == "gated_add": + gate = torch.sigmoid(self.parallel_v2_gate.to(dtype=x_base.dtype))[None, None, :] + return x_base + gate * attn_delta + (1.0 - gate) * second_delta + raise ValueError(f"Unknown PARALLEL_V2_MODE: {self.parallel_v2_mode}") + + def _parallel_v2_scaled_lanes(self, attn_out: Tensor, second_out: Tensor, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + attn_scale = self.parallel_v2_attn_scale.to(dtype=dtype)[None, None, :] + second_scale = self.parallel_v2_second_scale.to(dtype=dtype)[None, None, :] + return attn_scale * attn_out, second_scale * second_out + + def _record_parallel_v2_norm_ratios(self, x_pre: Tensor, attn_out: Tensor, second_out: Tensor) -> None: + if self.parallel_v2_capture_norm_ratios: + denom = x_pre.detach().float().norm().clamp_min(1e-8) + self.parallel_v2_attn_norm_ratio.copy_(attn_out.detach().float().norm() / denom) + self.parallel_v2_second_norm_ratio.copy_(second_out.detach().float().norm() / denom) + + + def forward( + self, + x: Tensor, + x0: Tensor, + ple_all: Tensor | None = None, + ple_apply=None, + shared_k: Tensor | None = None, + shared_v: Tensor | None = None, + prev_v: Tensor | None = None, + v0: Tensor | None = None, + ): + # -------------------------------------------------- + # 1) residual pre-mix + # -------------------------------------------------- + mix = self.resid_mix.to(dtype=x.dtype) + x_base = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # -------------------------------------------------- + # 2) optional PLE before attention / branch split + # -------------------------------------------------- + x_pre = self._apply_ple(x_base, ple_all, ple_apply, "pre_attn") + + # -------------------------------------------------- + # 3) shared scale + # -------------------------------------------------- + scale = self.layer_scale.to(dtype=x.dtype) if isinstance(self.layer_scale, nn.Parameter) else self.layer_scale + + # -------------------------------------------------- + # 4) attention branch + # -------------------------------------------------- + attn_in = self.attn_norm(x_pre) * scale + attn_out, k_eff, v_eff, raw_v = self.attn( + attn_in, + 
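+            # (added comment) shared_k / shared_v: the previous layer's K/V, passed when cross-layer
+            # KV sharing is active for this layer; prev_v: the previous layer's V for cross-layer V
+            # mixing; v0: the first block's raw (pre-mixing) V, blended in via vr_lambda when
+            # value_residual_enabled. The call returns (attn_out, k_eff, v_eff, raw_v).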
shared_k=shared_k, + shared_v=shared_v, + prev_v=prev_v, + v0=v0, + ) + + # -------------------------------------------------- + # 5) serial vs parallel residual update + # -------------------------------------------------- + if self.parallel_v2_enabled: + if ple_apply is not None and self.ple_mode == "post_attn" and not self.parallel_v2_use_post_attn_ple: + raise ValueError("PLE mode 'post_attn' is not supported when PARALLEL_V2_ENABLED=1") + + if self.parallel_v2_norm_shared: + second_in = attn_in + else: + second_in = self.mlp_norm(x_pre) * scale + second_out = self._parallel_v2_second_lane(second_in) + self._record_parallel_v2_norm_ratios(x_pre, attn_out, second_out) + + if self.parallel_v2_delayed_merge: + return x_pre, k_eff, v_eff, raw_v, attn_out, second_out + + x_out = self._parallel_v2_merge(x_pre, attn_out, second_out) + if self.parallel_v2_use_post_attn_ple: + x_out = self._apply_ple(x_out, ple_all, ple_apply, "post_attn") + + elif self.parallel_residual_enabled: + # In parallel mode, attention and MLP both see the same pre-branch state. + # This is the key difference from the serial Transformer block. + if self.mlp is None: + raise RuntimeError("MLP lane was not constructed") + mlp_in = self.mlp_norm(x_pre) * scale + mlp_out = self.mlp(mlp_in) + + attn_scale = self.parallel_attn_scale.to(dtype=x.dtype)[None, None, :] + mlp_scale = self.parallel_mlp_scale.to(dtype=x.dtype)[None, None, :] + + if self.parallel_residual_mode == "dual_add": + x_out = x_pre + attn_scale * attn_out + mlp_scale * mlp_out + elif self.parallel_residual_mode == "gated_add": + gate = torch.sigmoid(self.parallel_gate.to(dtype=x.dtype))[None, None, :] + x_out = x_pre + gate * (attn_scale * attn_out) + (1.0 - gate) * (mlp_scale * mlp_out) + else: + raise ValueError(f"Unknown PARALLEL_RESIDUAL_MODE: {self.parallel_residual_mode}") + + # NOTE: + # In parallel mode, "post_attn" becomes ambiguous because there is no single + # canonical state that is "after attention but before MLP" anymore. + # To keep semantics clean, we do not support post_attn PLE in parallel mode. + if ple_apply is not None and self.ple_mode == "post_attn": + raise ValueError("PLE mode 'post_attn' is not supported when PARALLEL_RESIDUAL_ENABLED=1") + + else: + # Serial Transformer block: + # MLP consumes the post-attention hidden state. 
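+            # (added comment) i.e. x_after_attn = x_pre + attn_scale * attn_out, optionally followed by
+            # post_attn PLE, then x_out = x_after_attn + mlp_scale * mlp(mlp_norm(x_after_attn) * scale),
+            # whereas the parallel branches above feed attention and the second lane the same x_pre.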
+ if self.mlp is None: + raise RuntimeError("MLP lane was not constructed") + x_after_attn = x_pre + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + + x_after_attn = self._apply_ple(x_after_attn, ple_all, ple_apply, "post_attn") + + mlp_in = self.mlp_norm(x_after_attn) * scale + mlp_out = self.mlp(mlp_in) + + x_out = x_after_attn + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * mlp_out + + # -------------------------------------------------- + # 6) optional PLE after FFN / final merge + # -------------------------------------------------- + x_out = self._apply_ple(x_out, ple_all, ple_apply, "post_ffn") + + return x_out, k_eff, v_eff, raw_v + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class CausalLocalMixing(nn.Module): + def __init__(self, dim: int, window_size: int = 4): + super().__init__() + self.window_size = window_size + self.dim = dim + w = torch.zeros(window_size, dim, dtype=torch.float32) + w[0, :] = 3.0 + self.mix_logits = nn.Parameter(w) + + def forward(self, x: Tensor) -> Tensor: + if self.window_size <= 1: + return x + w_soft = F.softmax(self.mix_logits.to(x.dtype), dim=0) + kernel = w_soft.flip(0).T.unsqueeze(1) + x_t = x.transpose(1, 2) + x_padded = F.pad(x_t, (self.window_size - 1, 0)) + out = F.conv1d(x_padded, kernel, groups=self.dim) + return out.transpose(1, 2) + + +class NGramHashEmbedding(nn.Module): + def __init__(self, vocab_size: int, dim: int, model_dim: int, max_n: int = 4): + super().__init__() + self.max_n = max_n + self.vocab_size = vocab_size + self.embeds = nn.ModuleList([nn.Embedding(vocab_size, dim) for _ in range(2, max_n + 1)]) + for emb in self.embeds: + nn.init.normal_(emb.weight, std=0.01) + self.proj = nn.Linear(dim, model_dim, bias=False) if dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.ngram_scales = nn.Parameter(torch.full((max_n - 1,), 0.05, dtype=torch.float32)) + + def ngram_hash(self, tokens: Tensor, n: int) -> Tensor: + t = tokens.to(torch.int64) + mod = self.vocab_size - 1 + out = torch.empty_like(t) + out[..., :n - 1] = mod + primes = [36313, 27191, 19393, 13127, 9767] + hash_val = t[..., n - 1:] * primes[0] + for i in range(1, n): + hash_val = torch.bitwise_xor(hash_val, t[..., n - 1 - i: -i] * primes[i]) + out[..., n - 1:] = hash_val % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + fused_h = None + for idx, n in enumerate(range(2, self.max_n + 1)): + h_n = self.embeds[idx](self.ngram_hash(token_ids, n)) + scaled_h = h_n * self.ngram_scales[idx].to(dtype=h_n.dtype) + fused_h = scaled_h if fused_h is None else fused_h + scaled_h + if self.proj is not None: + fused_h = self.proj(fused_h) + return fused_h + + +def compute_ngram_fade_scale(step, total_steps, enabled, start_frac, end_frac, min_scale=0.0) -> float: + if not enabled: + return 1.0 + if total_steps <= 0: + return 1.0 + p = step / float(total_steps) + start_frac = max(0.0, min(1.0, start_frac)) + end_frac = max(start_frac + 1e-8, min(1.0, end_frac)) + min_scale = max(0.0, min(1.0, min_scale)) + if p <= start_frac: + return 1.0 + if p >= end_frac: + return min_scale + alpha = (p - start_frac) / (end_frac - start_frac) + return (1.0 - alpha) + alpha * min_scale + + +class 
GPT(nn.Module): + def __init__(self, args: Hyperparameters, master_process: bool = True): + super().__init__() + self.args = args + self.fda_mode = args.fda_mode + self.skip_distance = 2 + self.num_layers = args.num_layers + self.cross_layer_kv_sharing_enabled = args.cross_layer_kv_sharing_enabled + self.cross_layer_kv_last_n_layers = args.cross_layer_kv_last_n_layers + self.cross_layer_kv_pairwise = args.cross_layer_kv_pairwise + self.tie_embeddings = args.tie_embeddings + self.tied_embed_init_std = args.tied_embed_init_std + self.logit_softcap = args.logit_softcap + self.scaledlm_head = args.scaledlm_head + self.mtphead_mlpmode = args.mtphead_mlpmode + self.depth_recur_enabled = args.depth_recur_enabled + self.num_stem_blocks = args.num_stem_blocks + self.num_core_blocks = args.num_core_blocks + self.num_core_repeats = args.num_core_repeats + self.num_tail_blocks = args.num_stem_blocks + args.parallel_v2_mode = args.parallel_v2_mode.strip().lower() + args.parallel_v2_second_lane = args.parallel_v2_second_lane.strip().lower() + self.parallel_v2_enabled = args.parallel_v2_enabled + self.parallel_v2_mode = args.parallel_v2_mode + self.parallel_v2_has_delayed_merge = ( + args.parallel_v2_enabled and + args.parallel_v2_mode == "delayed_merge" and + args.parallel_v2_last_n_layers > 0 + ) + if args.parallel_v2_enabled: + if args.parallel_residual_enabled: + raise ValueError("PARALLEL_RESIDUAL_ENABLED and PARALLEL_V2_ENABLED are mutually exclusive") + if args.parallel_v2_mode not in {"dual_add", "gated_add", "delayed_merge"}: + raise ValueError(f"Unknown PARALLEL_V2_MODE: {args.parallel_v2_mode}") + if args.parallel_v2_second_lane not in {"mlp", "gated_linear", "conv_gate", "ssm"}: + raise ValueError(f"Unknown PARALLEL_V2_SECOND_LANE: {args.parallel_v2_second_lane}") + if args.parallel_v2_last_n_layers < 0: + raise ValueError("PARALLEL_V2_LAST_N_LAYERS must be >= 0") + if args.parallel_v2_mode == "delayed_merge" and args.parallel_v2_delayed_merge_steps != 1: + raise ValueError("PARALLEL_V2_DELAYED_MERGE_STEPS currently only supports 1") + if args.ple_enabled and args.ple_mode == "post_attn" and not args.parallel_v2_use_post_attn_ple: + raise ValueError("Set PARALLEL_V2_USE_POST_ATTN_PLE=1 to opt into post-merge PLE semantics") + + model_dim = args.model_dim + num_layers = args.num_layers + self.tok_emb = nn.Embedding(args.vocab_size, model_dim) + + self.ple = PLEModule(args) + + self.smear_mode = args.smear_mode + if self.smear_mode: + self.local_mix = CausalLocalMixing(model_dim, window_size=args.smear_window) + if master_process: + print(f"Architecture: Local Causal Mixing (Window={args.smear_window})") + self.smear_gate = args.smear_gate + if self.smear_gate: + self.smear_gate_module = SmearGate(model_dim) + if master_process: + print("Architecture: SmearGate (1-step causal blend)") + + self.ngram_max_n = args.ngram_max_n + if args.ngram_vocab_size > 0 and self.ngram_max_n >= 2: + self.ngram = NGramHashEmbedding(args.ngram_vocab_size, args.ngram_dim, model_dim, max_n=self.ngram_max_n) + if master_process: + print(f"Architecture: Discrete N-Gram Hash (Max N={self.ngram_max_n})") + else: + self.ngram = None + self.register_buffer("ngram_global_scale_buf", torch.tensor(1.0, dtype=torch.float32), persistent=False) + + self.blocks = nn.ModuleList([ + Block( + args, + layer_idx=i, + xsa_enabled=(args.xsa_enabled and i >= num_layers - args.xsa_last_n_layers), + xsa_eps=args.xsa_eps, + ) for i in range(num_layers) + ]) + + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = 
num_layers - self.num_encoder_layers + self.bifpn_mode = args.bifpn_mode + self.bifpn2_mode = args.bifpn2_mode + + if self.depth_recur_enabled: + self.skip_weights = nn.Parameter(torch.ones(self.num_stem_blocks, model_dim, dtype=torch.float32)) + elif self.bifpn_mode: + w = torch.full((self.num_decoder_layers, self.num_encoder_layers), 0.1, dtype=torch.float32) + for i in range(self.num_decoder_layers): + sym_idx = self.num_encoder_layers - 1 - i + if sym_idx >= 0: + w[i, sym_idx] = 1.0 + self.bifpn_weights = nn.Parameter(w) + elif self.bifpn2_mode: + self.structured_bifpn = StructuredGroupSignedBiFPN( + num_decoder_layers=self.num_decoder_layers, + num_encoder_layers=self.num_encoder_layers, + model_dim=model_dim, + group_count=args.bifpn_group_count, + band_width=args.bifpn_band_width, + norm_eps=args.bifpn_norm_eps, + init_main=args.bifpn_init_main, + init_neighbor=args.bifpn_init_neighbor, + init_far=args.bifpn_init_far, + ) + elif self.fda_mode: + num_conn = max(0, num_layers - self.skip_distance) + self.skip_weights = nn.Parameter(torch.ones(num_conn, model_dim, dtype=torch.float32)) + else: + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + self.final_norm = RMSNorm(model_dim, affine=args.affine_norm) + self.lm_head = None if args.tie_embeddings else CastedLinear(model_dim, args.vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + + self.mtp_num_heads = args.mtp_num_heads + self.mtp_loss_weight = args.mtp_loss_weight + if self.mtp_num_heads > 0: + if self.mtphead_mlpmode: + self.mtp_heads = nn.ModuleList([ + nn.Sequential( + nn.Linear(model_dim, model_dim * 2, bias=False), + nn.GELU(), + nn.Linear(model_dim * 2, args.vocab_size, bias=False), + ) for _ in range(self.mtp_num_heads) + ]) + else: + self.mtp_heads = nn.ModuleList([CastedLinear(model_dim, args.vocab_size, bias=False) for _ in range(self.mtp_num_heads)]) + else: + self.mtp_heads = nn.ModuleList([]) + + self.max_logit_pre_cap = 0.0 + self._init_weights() + + def _init_weights(self): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def set_parallel_v2_norm_capture(self, enabled: bool) -> None: + for block in self.blocks: + if getattr(block, "parallel_v2_enabled", False): + block.parallel_v2_capture_norm_ratios = enabled + if enabled: + block.parallel_v2_attn_norm_ratio.fill_(float("nan")) + block.parallel_v2_second_norm_ratio.fill_(float("nan")) + + def _should_share_from_prev_layer(self, block_idx: int) -> bool: + if not self.cross_layer_kv_sharing_enabled: + return False + return block_idx >= self.num_layers - self.cross_layer_kv_last_n_layers + + def _merge_parallel_v2_pending( + self, + x: Tensor, + ple_all: Tensor | None, + ple_apply, + pending_layer_idx: int, + pending_attn_delta: Tensor | None, + pending_second_delta: Tensor | None, + ) -> tuple[Tensor, int, Tensor | None, Tensor | None]: + if self.parallel_v2_has_delayed_merge and pending_attn_delta is not None: + if pending_second_delta is None: + raise RuntimeError("Missing delayed Parallel v2 second-lane tensor") + x = x + pending_attn_delta + pending_second_delta + if ple_apply is not None and self.args.ple_mode == "post_attn" and self.args.parallel_v2_use_post_attn_ple: + x = ple_apply(x, 
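+            # (added comment) Delayed-merge bookkeeping: the attn / second-lane deltas captured from a
+            # delayed_merge layer are folded into the residual stream here, and that pending layer's PLE
+            # (post_attn only when PARALLEL_V2_USE_POST_ATTN_PLE is set, plus post_ffn) is applied only
+            # after the merge rather than inside the block.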
ple_all, pending_layer_idx) + if ple_apply is not None and self.args.ple_mode == "post_ffn": + x = ple_apply(x, ple_all, pending_layer_idx) + return x, -1, None, None + return x, pending_layer_idx, pending_attn_delta, pending_second_delta + + def _forward_hidden(self, input_ids: Tensor) -> Tensor: + last_v_for_cross_layer_v: Tensor | None = None + last_k_for_kv_sharing: Tensor | None = None + last_v_for_kv_sharing: Tensor | None = None + v0_global: Tensor | None = None + pending_parallel_v2_layer_idx = -1 + pending_parallel_v2_attn_delta: Tensor | None = None + pending_parallel_v2_second_delta: Tensor | None = None + + x = self.tok_emb(input_ids) + ple_all = self.ple.build_all(input_ids, x) + ple_apply = self.ple.apply if self.ple.enabled else None + + if getattr(self, "ngram", None) is not None: + scale = self.ngram_global_scale_buf.to(dtype=x.dtype) + x = x + scale * self.ngram(input_ids) + + x = F.rms_norm(x, (x.size(-1),)) + if self.smear_mode: + x = self.local_mix(x) + if self.smear_gate: + x = self.smear_gate_module(x) + x0 = x + + if self.depth_recur_enabled: + stem_skips: list[Tensor] = [] + for i in range(self.num_stem_blocks): + shared_k = last_k_for_kv_sharing if self._should_share_from_prev_layer(i) else None + shared_v = last_v_for_kv_sharing if self._should_share_from_prev_layer(i) else None + x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + self.blocks[i], i, x, x0, ple_all, ple_apply, shared_k, shared_v, + last_v_for_cross_layer_v, v0_global, + pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + ) + if v0_global is None: + v0_global = raw_v + last_v_for_cross_layer_v = v_eff + last_k_for_kv_sharing = k_eff + last_v_for_kv_sharing = v_eff + stem_skips.append(x) + + core_start = self.num_stem_blocks + for _ in range(self.num_core_repeats): + for j in range(self.num_core_blocks): + block_idx = core_start + j + shared_k = last_k_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + shared_v = last_v_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + self.blocks[block_idx], block_idx, x, x0, ple_all, ple_apply, shared_k, shared_v, + last_v_for_cross_layer_v, v0_global, + pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + ) + if v0_global is None: + v0_global = raw_v + last_v_for_cross_layer_v = v_eff + last_k_for_kv_sharing = k_eff + last_v_for_kv_sharing = v_eff + + tail_start = self.num_stem_blocks + self.num_core_blocks + for i in range(self.num_tail_blocks): + skip_x = stem_skips[self.num_stem_blocks - 1 - i] + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skip_x + block_idx = tail_start + i + shared_k = last_k_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + shared_v = last_v_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + self.blocks[block_idx], block_idx, x, x0, ple_all, ple_apply, shared_k, shared_v, + last_v_for_cross_layer_v, v0_global, + pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + ) + if v0_global is None: + v0_global = raw_v + 
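+ # keep this block's effective K/V (and the first raw V) visible to later layers for
+ # cross-layer value residuals and KV sharing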
last_v_for_cross_layer_v = v_eff + last_k_for_kv_sharing = k_eff + last_v_for_kv_sharing = v_eff + + elif self.bifpn2_mode: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + shared_k = last_k_for_kv_sharing if self._should_share_from_prev_layer(i) else None + shared_v = last_v_for_kv_sharing if self._should_share_from_prev_layer(i) else None + x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + self.blocks[i], i, x, x0, ple_all, ple_apply, shared_k, shared_v, + last_v_for_cross_layer_v, v0_global, + pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + ) + if v0_global is None: + v0_global = raw_v + skips.append(x) + last_v_for_cross_layer_v = v_eff + last_k_for_kv_sharing = k_eff + last_v_for_kv_sharing = v_eff + + for i in range(self.num_decoder_layers): + fusion_feature = self.structured_bifpn(skips=skips, decoder_idx=i, x_dtype=x.dtype) + x = x + fusion_feature + block_idx = self.num_encoder_layers + i + shared_k = last_k_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + shared_v = last_v_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + self.blocks[block_idx], block_idx, x, x0, ple_all, ple_apply, shared_k, shared_v, + last_v_for_cross_layer_v, v0_global, + pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + ) + if v0_global is None: + v0_global = raw_v + last_v_for_cross_layer_v = v_eff + last_k_for_kv_sharing = k_eff + last_v_for_kv_sharing = v_eff + + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + self.blocks[i], i, x, x0, ple_all, ple_apply, None, None, None, v0_global, + pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + ) + if v0_global is None: + v0_global = raw_v + skips.append(x) + + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + block_idx = self.num_encoder_layers + i + x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + self.blocks[block_idx], block_idx, x, x0, ple_all, ple_apply, None, None, None, v0_global, + pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + ) + if v0_global is None: + v0_global = raw_v + + x, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._merge_parallel_v2_pending( + x, + ple_all, + ple_apply, + pending_parallel_v2_layer_idx, + pending_parallel_v2_attn_delta, + pending_parallel_v2_second_delta, + ) + x = self.final_norm(x) + return x + + def _project_logits_from_hidden(self, x: Tensor) -> Tensor: + B, T, D = x.shape + x_flat = x.reshape(-1, D) + + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + if self.scaledlm_head: + logits_proj = logits_proj / math.sqrt(D) + else: + logits_proj = self.lm_head(x_flat) + if self.scaledlm_head: + logits_proj = logits_proj / math.sqrt(D) + + if not self.training or getattr(self, "_log_logits", False): + 
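+ # record the largest pre-softcap logit magnitude (eval, or when _log_logits is set) for logging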
self.max_logit_pre_cap = logits_proj.detach().abs().max() + + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return logits.view(B, T, -1) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self._forward_hidden(input_ids) + return self._project_logits_from_hidden(x) + + def _run_block( + self, + block: Block, + block_idx: int, + x: Tensor, + x0: Tensor, + ple_all: Tensor | None, + ple_apply, + shared_k: Tensor | None, + shared_v: Tensor | None, + prev_v: Tensor | None, + v0: Tensor | None, + pending_layer_idx: int, + pending_attn_delta: Tensor | None, + pending_second_delta: Tensor | None, + ): + x, pending_layer_idx, pending_attn_delta, pending_second_delta = self._merge_parallel_v2_pending( + x, + ple_all, + ple_apply, + pending_layer_idx, + pending_attn_delta, + pending_second_delta, + ) + out = block( + x, + x0, + ple_all=ple_all, + ple_apply=ple_apply, + shared_k=shared_k, + shared_v=shared_v, + prev_v=prev_v, + v0=v0, + ) + if self.parallel_v2_has_delayed_merge and block.parallel_v2_delayed_merge: + x, k_eff, v_eff, raw_v, attn_out, second_out = out + attn_delta, second_delta = block._parallel_v2_scaled_lanes(attn_out, second_out, x.dtype) + return x, k_eff, v_eff, raw_v, block_idx, attn_delta, second_delta + x, k_eff, v_eff, raw_v = out + return x, k_eff, v_eff, raw_v, pending_layer_idx, pending_attn_delta, pending_second_delta + + def forward(self, input_ids: Tensor, target_ids: Tensor, reduction: str = "mean", ngram_global_scale: float = 1.0): + x = self._forward_hidden(input_ids) + x_original = x + B, T, D = x.shape + + logits = self._project_logits_from_hidden(x) + targets = target_ids.reshape(-1) + + if reduction == "none": + loss_flat = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + targets, + reduction="none", + ) + loss_tokens = loss_flat.view(B, T) + return loss_tokens.mean(), loss_tokens + + main_loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + targets, + reduction="mean", + ) + + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + mtp_loss_sum = x_original.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = T - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x_original[:, :valid_t, :].reshape(-1, D) + mtp_targets = target_ids[:, k + 1:].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy( + mtp_logits.float(), mtp_targets, reduction="mean" + ) + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + + return main_loss + + # def forward(self, input_ids: Tensor, target_ids: Tensor, reduction: str = "mean", ngram_global_scale: float = 1.0): + # last_v_for_cross_layer_v: Tensor | None = None + # last_k_for_kv_sharing: Tensor | None = None + # last_v_for_kv_sharing: Tensor | None = None + # v0_global: Tensor | None = None + # pending_parallel_v2_layer_idx = -1 + # pending_parallel_v2_attn_delta: Tensor | None = None + # pending_parallel_v2_second_delta: Tensor | None = None + + # x = self.tok_emb(input_ids) + # ple_all = self.ple.build_all(input_ids, x) + # ple_apply = self.ple.apply if self.ple.enabled else None + + # if getattr(self, "ngram", None) is not None: + # scale = self.ngram_global_scale_buf.to(dtype=x.dtype) + # x = x + scale * self.ngram(input_ids) + + # x = F.rms_norm(x, (x.size(-1),)) + # if 
self.smear_mode: + # x = self.local_mix(x) + # if self.smear_gate: + # x = self.smear_gate_module(x) + # x0 = x + + # if self.depth_recur_enabled: + # stem_skips: list[Tensor] = [] + # for i in range(self.num_stem_blocks): + # shared_k = last_k_for_kv_sharing if self._should_share_from_prev_layer(i) else None + # shared_v = last_v_for_kv_sharing if self._should_share_from_prev_layer(i) else None + # x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + # self.blocks[i], i, x, x0, ple_all, ple_apply, shared_k, shared_v, + # last_v_for_cross_layer_v, v0_global, + # pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + # ) + # if v0_global is None: + # v0_global = raw_v + # last_v_for_cross_layer_v = v_eff + # last_k_for_kv_sharing = k_eff + # last_v_for_kv_sharing = v_eff + # stem_skips.append(x) + + # core_start = self.num_stem_blocks + # for _ in range(self.num_core_repeats): + # for j in range(self.num_core_blocks): + # block_idx = core_start + j + # shared_k = last_k_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + # shared_v = last_v_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + # x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + # self.blocks[block_idx], block_idx, x, x0, ple_all, ple_apply, shared_k, shared_v, + # last_v_for_cross_layer_v, v0_global, + # pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + # ) + # if v0_global is None: + # v0_global = raw_v + # last_v_for_cross_layer_v = v_eff + # last_k_for_kv_sharing = k_eff + # last_v_for_kv_sharing = v_eff + + # tail_start = self.num_stem_blocks + self.num_core_blocks + # for i in range(self.num_tail_blocks): + # skip_x = stem_skips[self.num_stem_blocks - 1 - i] + # x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skip_x + # block_idx = tail_start + i + # shared_k = last_k_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + # shared_v = last_v_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + # x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + # self.blocks[block_idx], block_idx, x, x0, ple_all, ple_apply, shared_k, shared_v, + # last_v_for_cross_layer_v, v0_global, + # pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + # ) + # if v0_global is None: + # v0_global = raw_v + # last_v_for_cross_layer_v = v_eff + # last_k_for_kv_sharing = k_eff + # last_v_for_kv_sharing = v_eff + + # elif self.bifpn2_mode: + # skips: list[Tensor] = [] + # for i in range(self.num_encoder_layers): + # shared_k = last_k_for_kv_sharing if self._should_share_from_prev_layer(i) else None + # shared_v = last_v_for_kv_sharing if self._should_share_from_prev_layer(i) else None + # x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + # self.blocks[i], i, x, x0, ple_all, ple_apply, shared_k, shared_v, + # last_v_for_cross_layer_v, v0_global, + # pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + # ) + # if v0_global is None: + # v0_global = raw_v + # skips.append(x) + # last_v_for_cross_layer_v = v_eff + # 
last_k_for_kv_sharing = k_eff + # last_v_for_kv_sharing = v_eff + + # for i in range(self.num_decoder_layers): + # fusion_feature = self.structured_bifpn(skips=skips, decoder_idx=i, x_dtype=x.dtype) + # x = x + fusion_feature + # block_idx = self.num_encoder_layers + i + # shared_k = last_k_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + # shared_v = last_v_for_kv_sharing if self._should_share_from_prev_layer(block_idx) else None + # x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + # self.blocks[block_idx], block_idx, x, x0, ple_all, ple_apply, shared_k, shared_v, + # last_v_for_cross_layer_v, v0_global, + # pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + # ) + # if v0_global is None: + # v0_global = raw_v + # last_v_for_cross_layer_v = v_eff + # last_k_for_kv_sharing = k_eff + # last_v_for_kv_sharing = v_eff + + # else: + # skips: list[Tensor] = [] + # for i in range(self.num_encoder_layers): + # x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + # self.blocks[i], i, x, x0, ple_all, ple_apply, None, None, None, v0_global, + # pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + # ) + # if v0_global is None: + # v0_global = raw_v + # skips.append(x) + # for i in range(self.num_decoder_layers): + # if skips: + # x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + # block_idx = self.num_encoder_layers + i + # x, k_eff, v_eff, raw_v, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._run_block( + # self.blocks[block_idx], block_idx, x, x0, ple_all, ple_apply, None, None, None, v0_global, + # pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta, + # ) + # if v0_global is None: + # v0_global = raw_v + + # x, pending_parallel_v2_layer_idx, pending_parallel_v2_attn_delta, pending_parallel_v2_second_delta = self._merge_parallel_v2_pending( + # x, + # ple_all, + # ple_apply, + # pending_parallel_v2_layer_idx, + # pending_parallel_v2_attn_delta, + # pending_parallel_v2_second_delta, + # ) + # x = self.final_norm(x) + # x_original = x + # x = x.reshape(-1, x.size(-1)) + # targets = target_ids.reshape(-1) + + # if self.tie_embeddings: + # logits_proj = F.linear(x, self.tok_emb.weight) + # if self.scaledlm_head: + # logits_proj = logits_proj / math.sqrt(x.size(-1)) + # else: + # logits_proj = self.lm_head(x) + # if self.scaledlm_head: + # logits_proj = logits_proj / math.sqrt(x.size(-1)) + + # if not self.training or getattr(self, "_log_logits", False): + # self.max_logit_pre_cap = logits_proj.detach().abs().max() + + # logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + # if reduction == "none": + # loss_flat = F.cross_entropy(logits.float(), targets, reduction="none") + # loss_tokens = loss_flat.view(input_ids.shape[0], input_ids.shape[1]) + # return loss_tokens.mean(), loss_tokens + + # main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + + # if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + # _, seqlen, dim = x_original.shape + # mtp_loss_sum = x_original.new_zeros(()) + # mtp_loss_count = 0 + # for k, mtp_head in enumerate(self.mtp_heads): + # valid_t = seqlen - (k + 1) + # if valid_t <= 0: + # continue 
+ # mtp_hidden = x_original[:, :valid_t, :].reshape(-1, dim) + # mtp_targets = target_ids[:, k + 1:].reshape(-1) + # mtp_logits_proj = mtp_head(mtp_hidden) + # mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + # mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + # mtp_loss_count += 1 + # if mtp_loss_count > 0: + # main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + + # return main_loss + + +# ============================================================ +# DATA STREAMING +# ============================================================ + +class TokenStream: + def __init__(self, pattern: str): + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + if not self.files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[0]) + self.pos = 0 + + def _advance_file(self): + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos: self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int): + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start: start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ============================================================ +# EVAL +# ============================================================ + +@torch.no_grad() +def eval_val(args: Hyperparameters, model: nn.Module, rank: int, world_size: int, device: torch.device, grad_accum_steps: int, val_tokens: Tensor, base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError("VAL_BATCH_SIZE too small") + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = 
val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +@torch.no_grad() +def eval_val_sliding(args: Hyperparameters, model: nn.Module, rank: int, world_size: int, device: torch.device, grad_accum_steps: int, val_tokens: Tensor, base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor): + model.eval() + seq_len = args.train_seq_len + stride = args.eval_stride + batch_seqs = args.eval_batch_seqs + + total_loss_sum = torch.zeros(1, device=device, dtype=torch.float64) + total_token_count = torch.zeros(1, device=device, dtype=torch.float64) + total_byte_count = torch.zeros(1, device=device, dtype=torch.float64) + + max_start = val_tokens.numel() - 1 - seq_len + starts = list(range(0, max_start + 1, stride)) + if starts[-1] != max_start: + starts.append(max_start) + starts = starts[rank::world_size] + + def _score_batch(xb: Tensor, yb: Tensor): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + _, loss_tokens = model(xb, yb, reduction="none") + score_loss = loss_tokens[:, -stride:] + token_count = torch.tensor(score_loss.numel(), device=device, dtype=torch.float64) + scored_y = yb[:, -stride:] + scored_x = xb[:, -stride:] + byte_count = tokens_to_bytes_count(scored_x, scored_y, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut).to(torch.float64) + loss_sum = score_loss.sum(dtype=torch.float64) + return loss_sum, token_count, byte_count + + batch_x = [] + batch_y = [] + for start in starts: + chunk = val_tokens[start: start + seq_len + 1].to(device=device, dtype=torch.int64, non_blocking=True) + batch_x.append(chunk[:-1]) + batch_y.append(chunk[1:]) + if len(batch_x) == batch_seqs: + xb = torch.stack(batch_x) + yb = torch.stack(batch_y) + ls, tc, bc = _score_batch(xb, yb) + total_loss_sum += ls + total_token_count += tc + total_byte_count += bc + batch_x.clear() + batch_y.clear() + if batch_x: + xb = torch.stack(batch_x) + yb = torch.stack(batch_y) + ls, tc, bc = _score_batch(xb, yb) + total_loss_sum += ls + total_token_count += tc + total_byte_count += bc + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(total_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(total_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(total_byte_count, op=dist.ReduceOp.SUM) + + val_loss = (total_loss_sum / total_token_count).item() + val_bpb = (total_loss_sum / (math.log(2.0) * 
total_byte_count)).item() + model.train() + return val_loss, val_bpb + + +class BytePPM: + """ + Lightweight order-k byte PPM-style model for eval-time mixture experiments. + This is intentionally simple and readable, not a hyper-optimized PPM-D clone. + + The first version used tuple contexts and Counter objects. That is convenient, + but it creates millions of small Python objects on multi-million-token evals. + This version packs short byte contexts into integer keys and keeps single-byte + continuations in one integer until a context needs to branch. + """ + + def __init__(self, order: int = 5, max_contexts: int = 0): + self.order = max(0, int(order)) + self.max_contexts = max(0, int(max_contexts)) + self._ctx_shift = 8 * max(1, self.order) + self.ctx_counts: dict[int, int | dict[int, int]] = {} + self.ctx_totals: dict[int, int] = {} + self.skipped_new_contexts = 0 + + @property + def context_count(self) -> int: + return len(self.ctx_counts) + + @staticmethod + def _pack_single(byte: int, count: int) -> int: + return (int(count) << 8) | int(byte) + + @staticmethod + def _unpack_single(entry: int) -> tuple[int, int]: + return entry & 0xFF, entry >> 8 + + def _ctx_key(self, history: bytearray, k: int) -> int: + packed = 0 + if k > 0: + start = len(history) - k + for i in range(start, len(history)): + packed = (packed << 8) | int(history[i]) + return (int(k) << self._ctx_shift) | packed + + def _entry_true_prob_and_conf( + self, + entry: int | dict[int, int], + total: int, + true_byte: int, + ) -> tuple[float, float]: + if total <= 0: + return 0.0, 0.0 + if isinstance(entry, int): + byte, count = self._unpack_single(entry) + p_true = float(count) / float(total) if byte == true_byte else 0.0 + return p_true, float(count) / float(total) + + true_count = entry.get(true_byte, 0) + max_count = max(entry.values()) if entry else 0 + return float(true_count) / float(total), float(max_count) / float(total) + + def predict_true_and_conf(self, history: bytearray, true_byte: int) -> tuple[float, float]: + max_k = min(self.order, len(history)) + for k in range(max_k, -1, -1): + key = self._ctx_key(history, k) + entry = self.ctx_counts.get(key) + if entry is None: + continue + return self._entry_true_prob_and_conf( + entry, + self.ctx_totals.get(key, 0), + int(true_byte), + ) + return 1.0 / 256.0, 1.0 / 256.0 + + def _add_count(self, key: int, next_byte: int) -> None: + entry = self.ctx_counts.get(key) + if entry is None: + if self.max_contexts > 0 and len(self.ctx_counts) >= self.max_contexts and key != 0: + self.skipped_new_contexts += 1 + return + self.ctx_counts[key] = self._pack_single(next_byte, 1) + self.ctx_totals[key] = 1 + return + + self.ctx_totals[key] += 1 + if isinstance(entry, int): + old_byte, old_count = self._unpack_single(entry) + if old_byte == next_byte: + self.ctx_counts[key] = self._pack_single(next_byte, old_count + 1) + else: + self.ctx_counts[key] = {old_byte: old_count, int(next_byte): 1} + else: + entry[int(next_byte)] = entry.get(int(next_byte), 0) + 1 + + def update(self, history: bytearray, next_byte: int) -> None: + max_k = min(self.order, len(history)) + for k in range(0, max_k + 1): + self._add_count(self._ctx_key(history, k), int(next_byte)) + + +def nn_byte_true_prob_for_token( + p_token_target: float, + token_byte_len: int, + floor: float = 1e-6, +) -> float: + """ + True-byte probability for the spread-root token->byte projection. + The full 256-way distribution is never materialized during PPM eval. 
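+ The per-byte peak is p_token_target ** (1 / token_byte_len); the remaining 255 bytes sit
+ at the uniform floor, so the returned true-byte probability is peak / (peak + 255 * floor).
+ """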
+ """ + if token_byte_len <= 0: + return 0.0 + L = int(token_byte_len) + p_target = max(min(p_token_target, 1.0), 1e-12) ** (1.0 / L) + peak = max(float(p_target), float(floor)) + return peak / (peak + 255.0 * float(floor)) + + +def mix_true_byte_prob( + nn_p_true: float, + ppm_p_true: float, + ppm_conf: float, + conf_threshold: float, + lambda_lo: float, + lambda_hi: float, +) -> tuple[float, float]: + lam = lambda_lo if ppm_conf >= conf_threshold else lambda_hi + return (1.0 - lam) * nn_p_true + lam * ppm_p_true, lam + + +@torch.no_grad() +def eval_val_with_ppm_mixture( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + token_piece_bytes: list[bytes], + has_leading_space: list[bool], + is_boundary_token: list[bool], + log0=print, +): + """ + Experimental byte-level mixture eval: + neural token LM + lightweight byte-level PPM. + Single-rank only for now. + """ + if world_size != 1: + raise NotImplementedError("PPM mixture eval currently supports single-rank only.") + + if args.nn_byte_projection != "spread_root": + raise ValueError(f"Unsupported NN_BYTE_PROJECTION: {args.nn_byte_projection}") + + model.eval() + + seq_len = args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + subset_limit = args.ppm_subset_tokens if args.ppm_subset_tokens > 0 else total_seqs * seq_len + + ppm = BytePPM(order=args.ppm_order, max_contexts=args.ppm_max_contexts) + history = bytearray() + + total_nll = 0.0 + total_bytes = 0 + total_tokens_seen = 0 + + for seq_idx in range(total_seqs): + if total_tokens_seen >= subset_limit: + break + + raw_start = seq_idx * seq_len + raw_end = raw_start + seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + + x = local[:-1].reshape(1, seq_len) + y = local[1:].reshape(1, seq_len) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = model.forward_logits(x) + + probs = torch.softmax(logits.float(), dim=-1)[0] # [T, V] + target_probs = probs.gather(1, y[0].unsqueeze(1)).squeeze(1).detach().cpu().tolist() + x_ids = x[0].tolist() + y_ids = y[0].tolist() + + for t, (prev_id, tgt_id) in enumerate(zip(x_ids, y_ids)): + if total_tokens_seen >= subset_limit: + break + + p_tok = float(target_probs[t]) + tok_bytes = reconstruct_token_bytes( + prev_token_id=prev_id, + token_id=tgt_id, + token_piece_bytes=token_piece_bytes, + has_leading_space=has_leading_space, + is_boundary_token=is_boundary_token, + ) + + nn_p_true = nn_byte_true_prob_for_token( + p_token_target=p_tok, + token_byte_len=len(tok_bytes), + floor=args.nn_byte_uniform_floor, + ) + + for true_b in tok_bytes: + ppm_p_true, ppm_conf = ppm.predict_true_and_conf(history, true_b) + p_true, _ = mix_true_byte_prob( + nn_p_true=nn_p_true, + ppm_p_true=ppm_p_true, + ppm_conf=ppm_conf, + conf_threshold=args.ppm_conf_threshold, + lambda_lo=args.lambda_lo, + lambda_hi=args.lambda_hi, + ) + + total_nll += -math.log(max(p_true, 1e-12)) + total_bytes += 1 + + # update AFTER scoring + ppm.update(history, true_b) + history.append(true_b) + + total_tokens_seen += 1 + + if seq_idx > 0 and seq_idx % 500 == 0: + mix_bpb = total_nll / (math.log(2.0) * max(total_bytes, 1)) + log0( + f"ppm_mix_progress seq:{seq_idx}/{total_seqs} " + f"tokens:{total_tokens_seen} bytes:{total_bytes} " + f"contexts:{ppm.context_count} skipped_ctx:{ppm.skipped_new_contexts} " + f"mix_bpb:{mix_bpb:.6f}" + ) + + mix_bpb = total_nll / (math.log(2.0) * max(total_bytes, 1)) + return mix_bpb 
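+
+# Illustrative sketch (not wired into any code path): how one byte is scored with the
+# NN/PPM mixture above. The numeric values are hypothetical; in eval_val_with_ppm_mixture
+# the token probability comes from the model softmax and the thresholds from Hyperparameters.
+#
+#   ppm = BytePPM(order=3)
+#   history = bytearray(b"the qui")
+#   ppm_p_true, ppm_conf = ppm.predict_true_and_conf(history, ord("c"))
+#   nn_p_true = nn_byte_true_prob_for_token(p_token_target=0.42, token_byte_len=3)
+#   p_true, lam = mix_true_byte_prob(nn_p_true, ppm_p_true, ppm_conf,
+#                                    conf_threshold=0.9, lambda_lo=0.2, lambda_hi=0.75)
+#   nll_bits = -math.log(max(p_true, 1e-12)) / math.log(2.0)
+#   ppm.update(history, ord("c"))   # update only after scoring, as in the eval loop
+#   history.append(ord("c"))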
+ +# ============================================================ +# LEGAL SCORE-FIRST LoRA-TTT +# ============================================================ + +def build_ttt_optimizer(args: Hyperparameters, lora_mgr: LoRATTTManager): + params = list(lora_mgr.lora_parameters()) + if not params: + raise RuntimeError("No LoRA params found for TTT") + return torch.optim.SGD( + params, + lr=args.ttt_lr, + momentum=args.ttt_momentum, + weight_decay=args.ttt_weight_decay, + ) + + +def eval_val_sliding_lora_ttt( + args: Hyperparameters, + base_model: nn.Module, + lora_mgr: LoRATTTManager, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + log0=print, +): + """ + Legal score-first TTT. + + Phase 1: score a chunk under inference_mode. + Phase 2: update LoRA on that already-scored chunk. + Last chunk is never trained after scoring. + """ + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + window_starts = list(range(0, total_tokens - seq_len + 1, seq_len)) + + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + ci = min(ws // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + lora_mgr.set_enabled(True) + optimizer = build_ttt_optimizer(args, lora_mgr) + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), args.ttt_batch_seqs): + batch_ws = my_windows[bi: bi + args.ttt_batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + valid_lens = [] + for i_w, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + tok = val_tokens[ws: end + 1].to(dtype=torch.int64, device=device) + x_batch[i_w, :wlen] = tok[:-1] + y_batch[i_w, :wlen] = tok[1:] + valid_lens.append(wlen) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + _, nll = base_model(x_batch, y_batch, reduction="none") + for i_w, wlen in enumerate(valid_lens): + scored_nll = nll[i_w, :wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen) + tgt = y_batch[i_w, :wlen] + prev = x_batch[i_w, :wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + is_last = ci == num_chunks - 1 + if not is_last and args.ttt_epochs > 0: + lora_mgr.reset_chunk_state() + base_model.train() + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg["lr"] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_seqs = my_seq_e - 
my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_seqs) + start_tok = chunk_start + (my_seq_s + bs) * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok: end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in lora_mgr.lora_parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(list(lora_mgr.lora_parameters()), args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 20 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) + log0(f" lora_ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + lora_mgr.set_enabled(False) + base_model.eval() + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + log0(f"lora_ttt:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} elapsed={time.perf_counter()-t0:.1f}s") + return val_loss, val_bpb + + +# ============================================================ +# QUANTIZATION (minimal placeholder: keep your v4 implementation) +# ============================================================ + +# ============================================================ +# INT8 + ZLIB SERIALIZATION +# ============================================================ + +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + p for p in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") if p +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +DYNAMIC_CLIP_Q_LIST = [ + float(p) / 100.0 + for p in os.environ.get( + "DYNAMIC_CLIP_PERCENTILES", + "100.0,99.9999,99.9995,99.995,99.99,99.95,99.9,99.8", + ).split(",") + if p.strip() +] + + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_scale, best_mse = None, None, float("inf") + for q_pct in DYNAMIC_CLIP_Q_LIST: + if q_pct >= 1.0: + clip_abs = t32.abs().max(dim=1).values + else: + clip_abs = ( + torch.quantile(t32.abs(), q_pct, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32, device=t32.device) + ) + clipped = 
torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8) + mse = F.mse_loss(q.float() * scale[:, None], t32).item() + if best_q is None or mse < best_mse: + best_mse, best_q, best_scale = mse, q, scale + return best_q.contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", + "baseline_tensor_bytes", "int8_payload_bytes"), 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def dequantize_state_dict_int8(obj: dict) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def compute_submission_size(state_dict: dict[str, Tensor], code: str) -> tuple[int, int, dict]: + """Return (quant_file_bytes, code_bytes, stats) without writing to disk.""" + quant_obj, stats = quantize_state_dict_int8(state_dict) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = 
zlib.compress(buf.getvalue(), level=9) + code_bytes = len(code.encode("utf-8")) + return len(blob), code_bytes, stats + + +def serialize_model(base_model: nn.Module, output_dir: str, code: str, log0) -> tuple[str, dict]: + """Save final_model.pt (raw float) + final_model.int8.ptz (int8+zlib), log sizes.""" + # Raw checkpoint — used by eval_only_main and for debugging. + raw_path = os.path.join(output_dir, "final_model.pt") + torch.save(base_model.state_dict(), raw_path) + log0(f"saved raw checkpoint: {raw_path} ({os.path.getsize(raw_path):,} bytes)") + + # Quantized + compressed artifact. + quant_obj, stats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + quant_path = os.path.join(output_dir, "final_model.int8.ptz") + with open(quant_path, "wb") as f: + f.write(blob) + quant_file_bytes = os.path.getsize(quant_path) + code_bytes = len(code.encode("utf-8")) + ratio = stats["baseline_tensor_bytes"] / max(stats["int8_payload_bytes"], 1) + limit_bytes = 16 * 1024 * 1024 # 16 MB + total_bytes = quant_file_bytes + code_bytes + log0( + f"model_size int8+zlib:{quant_file_bytes} bytes " + f"code:{code_bytes} bytes total:{total_bytes} bytes " + f"limit:16MB({limit_bytes}) " + f"{'FITS' if total_bytes <= limit_bytes else 'EXCEEDS_LIMIT'}" + ) + log0( + f" payload:{stats['int8_payload_bytes']} " + f"raw_torch:{buf.tell()} compression_ratio:{ratio:.2f}x" + ) + return quant_path, stats + + +# ============================================================ +# MAIN +# ============================================================ +def assert_all_params_on_device(module: nn.Module, device: torch.device) -> None: + bad = [] + for name, p in module.named_parameters(): + if p.device != device: + bad.append((name, str(p.device), str(p.dtype), tuple(p.shape))) + if bad: + lines = ["Parameters on wrong device:"] + for name, dev, dtype, shape in bad[:100]: + lines.append(f" {name}: device={dev} dtype={dtype} shape={shape}") + raise RuntimeError("\n".join(lines)) + +def is_lora_param(name: str) -> bool: + name = name.lower() + return ("lora_" in name) or (".lora." 
in name) or ("lora_a" in name) or ("lora_b" in name) + + +def collect_parallel_v2_telemetry(base_model: GPT) -> dict: + args = base_model.args + layers = [] + total_second_lane_params = 0 + for i, block in enumerate(base_model.blocks): + if not getattr(block, "parallel_v2_enabled", False): + continue + second_lane = block.mlp if block.parallel_v2_second_lane_name == "mlp" else block.second_lane + second_lane_params = count_trainable_params(second_lane) + total_second_lane_params += second_lane_params + layer = { + "layer": i, + "mode": block.parallel_v2_mode, + "second_lane": block.parallel_v2_second_lane_name, + "second_lane_params": second_lane_params, + "attn_scale_mean": float(block.parallel_v2_attn_scale.detach().float().mean().item()), + "second_scale_mean": float(block.parallel_v2_second_scale.detach().float().mean().item()), + } + if hasattr(block, "parallel_v2_gate"): + layer["gate_mean"] = float(torch.sigmoid(block.parallel_v2_gate.detach().float()).mean().item()) + if hasattr(block, "parallel_v2_attn_norm_ratio"): + attn_ratio = block.parallel_v2_attn_norm_ratio.detach().float() + second_ratio = block.parallel_v2_second_norm_ratio.detach().float() + if bool(torch.isfinite(attn_ratio).item()): + layer["attn_norm_ratio"] = float(attn_ratio.item()) + if bool(torch.isfinite(second_ratio).item()): + layer["second_norm_ratio"] = float(second_ratio.item()) + layers.append(layer) + return { + "parallel_v2_enabled": int(args.parallel_v2_enabled), + "parallel_v2_mode": args.parallel_v2_mode, + "parallel_v2_second_lane": args.parallel_v2_second_lane, + "parallel_v2_last_n_layers": args.parallel_v2_last_n_layers, + "parallel_v2_active_layers": [layer["layer"] for layer in layers], + "parallel_v2_second_lane_params_total": total_second_lane_params, + "parallel_v2_layers": layers, + } + + +def write_jsonl(path: str, payload: dict) -> None: + parent = os.path.dirname(path) + if parent: + os.makedirs(parent, exist_ok=True) + with open(path, "a", encoding="utf-8") as f: + f.write(json.dumps(payload, sort_keys=True) + "\n") + + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 8 // max(world_size, 1))) + grad_scale = 1.0 / grad_accum_steps + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + outdir = args.output_dir or "logs" + os.makedirs(outdir, exist_ok=True) + logfile = os.path.join(outdir, f"{time.strftime('%Y%m%d_%H%M%S')}.txt") + print(logfile) + + def log0(msg: str, console: bool = True): + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with 
open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"SentencePiece .model expected: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"VOCAB_SIZE={args.vocab_size} != tokenizer vocab_size={int(sp.vocab_size())}") + + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + + token_piece_bytes, sp_has_leading_space, sp_is_boundary_token = build_sentencepiece_byte_tables( + sp, args.vocab_size + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # base_model = GPT(args, master_process=master_process).to(device).bfloat16() + # for module in base_model.modules(): + # if isinstance(module, CastedLinear): + # module.float() + # restore_low_dim_params_to_fp32(base_model) + + # # Inject LoRA wrappers BEFORE torch.compile. + # lora_mgr = None + # if args.ttt_enabled and args.ttt_mode == "lora" and args.lora_ttt_enabled: + # lora_mgr = LoRATTTManager(base_model, args) + # lora_mgr.inject() + # if master_process: + # print(f"TTT: LoRA adapters injected for targets={args.lora_ttt_targets}") + + base_model = GPT(args, master_process=master_process) + + # Inject LoRA wrappers BEFORE moving model to device / dtype and BEFORE torch.compile. 
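+ # (Assumed rationale: parameters added after DDP wrapping are not registered for gradient
+ # sync, and the .to(device).bfloat16() cast below must also cover the adapter weights.)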
+ lora_mgr = None + if args.ttt_enabled and args.ttt_mode == "lora" and args.lora_ttt_enabled: + lora_mgr = LoRATTTManager(base_model, args) + lora_mgr.inject() + if master_process: + print(f"TTT: LoRA adapters injected for targets={args.lora_ttt_targets}") + + base_model = base_model.to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + assert_all_params_on_device(base_model, device) + + + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 + and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + and not is_lora_param(name) + and p.requires_grad + ] + + scalar_params = [ + p + for name, p in block_named_params + if (p.ndim != 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) + and not is_lora_param(name) + and p.requires_grad + ] + + lora_params = [ + p + for name, p in base_model.named_parameters() + if is_lora_param(name) + and p.requires_grad + ] + + if master_process: + print(f"lora_params:{sum(p.numel() for p in lora_params)}") + + if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if hasattr(base_model, "bifpn_weights") and base_model.bifpn_weights.numel() > 0: + scalar_params.append(base_model.bifpn_weights) + if hasattr(base_model, "structured_bifpn"): + scalar_params.append(base_model.structured_bifpn.weights) + if hasattr(base_model, "mtp_heads") and base_model.mtp_heads is not None: + for p in base_model.mtp_heads.parameters(): + scalar_params.append(p) + if hasattr(base_model, "ngram") and base_model.ngram is not None: + if base_model.ngram.proj is not None: + scalar_params.append(base_model.ngram.proj.weight) + scalar_params.append(base_model.ngram.ngram_scales) + if hasattr(base_model, "ple") and getattr(base_model.ple, "enabled", False): + scalar_params.extend(list(base_model.ple.parameters())) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_param_groups = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if hasattr(base_model, "ngram") and base_model.ngram is not None: + for emb in base_model.ngram.embeds: + tok_param_groups.append({"params": [emb.weight], "lr": token_lr, "base_lr": token_lr}) + + optimizer_tok = torch.optim.Adam(tok_param_groups, betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in 
base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0(f"tie_embeddings:{args.tie_embeddings} matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}") + log0(f"ttt_enabled:{args.ttt_enabled} ttt_mode:{args.ttt_mode} lora_ttt_enabled:{args.lora_ttt_enabled}") + parallel_v2_snapshot = collect_parallel_v2_telemetry(base_model) + log0( + f"parallel_v2_enabled:{parallel_v2_snapshot['parallel_v2_enabled']} " + f"mode:{parallel_v2_snapshot['parallel_v2_mode']} " + f"second_lane:{parallel_v2_snapshot['parallel_v2_second_lane']} " + f"active_layers:{parallel_v2_snapshot['parallel_v2_active_layers']} " + f"second_lane_params:{parallel_v2_snapshot['parallel_v2_second_lane_params_total']}" + ) + if args.parallel_v2_enabled: + for layer in parallel_v2_snapshot["parallel_v2_layers"]: + gate_msg = f" gate_mean:{layer['gate_mean']:.6f}" if "gate_mean" in layer else "" + log0( + f"parallel_v2_layer:{layer['layer']} " + f"attn_scale_mean:{layer['attn_scale_mean']:.6f} " + f"second_scale_mean:{layer['second_scale_mean']:.6f}" + f"{gate_msg} params:{layer['second_lane_params']}" + ) + if master_process and args.telemetry_every > 0: + write_jsonl(args.telemetry_file, {"event": "init", **parallel_v2_snapshot}) + if args.ttt_enabled and args.lora_ttt_enabled: + log0( + f"lora_ttt_rank:{args.lora_ttt_rank} alpha:{args.lora_ttt_alpha} " + f"warmA:{int(args.lora_ttt_warm_start_a)} resetB:{int(args.lora_ttt_reset_b_each_chunk)} " + f"chunk:{args.ttt_chunk_tokens} wd:{args.ttt_weight_decay}" + ) + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all(): + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + use_walltime_stop = args.stop_mode == "walltime" + use_steps_stop = args.stop_mode == "steps" + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if (use_walltime_stop and args.max_wallclock_seconds > 0) else None + hard_step_limit = args.max_train_steps if (use_steps_stop and args.max_train_steps > 0) else args.iterations + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if use_steps_stop: + warmdown_start = max(hard_step_limit - args.warmdown_iters, 0) + return max((hard_step_limit - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < hard_step_limit else 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + 
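+ # bf16 warmup forward; the 1/grad_accum_steps factor below keeps the accumulated gradient
+ # equivalent to one full batch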
+                    warmup_loss = model(x, y)
+                (warmup_loss * (1.0 / grad_accum_steps)).backward()
+            for opt in optimizers:
+                opt.step()
+            zero_grad_all()
+            log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")
+        base_model.load_state_dict(initial_model_state, strict=True)
+        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
+            opt.load_state_dict(state)
+        zero_grad_all()
+        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
+
+    training_time_ms = 0.0
+    stop_after_step: int | None = None
+    torch.cuda.synchronize()
+    t0 = time.perf_counter()
+
+    ema_state = None
+    if args.ema_enabled:
+        log0(f"EMA Enabled: decay={args.ema_decay}")
+        ema_state = {name: p.detach().float().clone() for name, p in base_model.state_dict().items()}
+        ema_tensors_list = list(ema_state.values())
+        model_tensors_list = list(base_model.state_dict().values())
+
+    qat_start_step = int(hard_step_limit * (1.0 - args.late_qat_ratio))
+    if args.late_qat_ratio > 0:
+        log0(f"Scheduled Late QAT to start at step {qat_start_step} (last {args.late_qat_ratio*100:.1f}%)")
+
+    step = 0
+    muon_momentum = args.muon_momentum_warmup_start if args.muon_momentum_warmup_steps > 0 else args.muon_momentum
+
+    while True:
+        last_step = step == hard_step_limit or (stop_after_step is not None and step >= stop_after_step)
+        should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
+
+        if should_validate:
+            torch.cuda.synchronize()
+            training_time_ms += 1000.0 * (time.perf_counter() - t0)
+            eval_fn = eval_val_sliding if args.eval_use_sliding_window else eval_val
+            eval_model = base_model if args.eval_use_sliding_window else model
+            val_loss, val_bpb = eval_fn(args, eval_model, rank, world_size, device, grad_accum_steps, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut)
+            log0(f"step:{step}/{hard_step_limit} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms")
+            torch.cuda.synchronize()
+            t0 = time.perf_counter()
+
+        if last_step:
+            break
+
+        if step == qat_start_step and args.late_qat_ratio > 0.0:
+            log0(f"[Step {step}] Activating Late QAT — enabling branchless STE quantization.")
+            for mod in base_model.modules():
+                if isinstance(mod, CastedLinear):
+                    mod.qat_alpha.fill_(1.0)
+
+        step_t0 = time.perf_counter()
+        elapsed_ms = training_time_ms + 1000.0 * (step_t0 - t0)
+        scale = lr_mul(step, elapsed_ms)
+        zero_grad_all()
+        train_loss = torch.zeros((), device=device)
+        last_telemetry_x: Tensor | None = None
+        last_telemetry_y: Tensor | None = None
+
+        ngram_global_scale = compute_ngram_fade_scale(
+            step=step,
+            total_steps=hard_step_limit,
+            enabled=args.ngram_fade_enable,
+            start_frac=args.ngram_fade_start_frac,
+            end_frac=args.ngram_fade_end_frac,
+            min_scale=args.ngram_fade_min_scale,
+        )
+
+        for micro_step in range(grad_accum_steps):
+            if distributed:
+                model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
+            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                base_model.ngram_global_scale_buf.fill_(float(ngram_global_scale))
+                loss = model(x, y)
+            last_telemetry_x = x.detach()
+            last_telemetry_y = y.detach()
+            train_loss += loss.detach()
+            (loss * grad_scale).backward()
+        train_loss /= grad_accum_steps
+
+        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
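+        # Muon momentum warmup: linearly interpolate from the warmup-start value to the
+        # final args.muon_momentum over the first muon_momentum_warmup_steps steps.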
+        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
+        for group in optimizer_muon.param_groups:
+            group["momentum"] = muon_momentum
+        for opt in optimizers:
+            for group in opt.param_groups:
+                group["lr"] = group["base_lr"] * scale
+
+        if args.grad_clip_norm > 0:
+            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
+        for opt in optimizers:
+            opt.step()
+        zero_grad_all()
+
+        if args.ema_enabled and ema_state is not None:
+            with torch.no_grad():
+                update_ema_fused(ema_tensors_list, model_tensors_list, args.ema_decay)
+
+        step += 1
+        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0):
+            log0(f"step:{step}/{hard_step_limit} train_loss:{train_loss.item():.4f} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms")
+
+        if master_process and args.telemetry_every > 0 and step % args.telemetry_every == 0:
+            if args.parallel_v2_enabled and args.parallel_v2_log_norm_ratios and last_telemetry_x is not None and last_telemetry_y is not None:
+                was_training = base_model.training
+                base_model.set_parallel_v2_norm_capture(True)
+                try:
+                    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                        base_model(last_telemetry_x, last_telemetry_y)
+                finally:
+                    base_model.set_parallel_v2_norm_capture(False)
+                    base_model.train(was_training)
+            write_jsonl(
+                args.telemetry_file,
+                {
+                    "event": "train",
+                    "step": step,
+                    "train_loss": float(train_loss.item()),
+                    "lr_scale": float(scale),
+                    **collect_parallel_v2_telemetry(base_model),
+                },
+            )
+
+        if use_steps_stop:
+            reached_cap = step >= hard_step_limit
+        else:
+            reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+        if distributed and max_wallclock_ms is not None:
+            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
+            reached_cap = bool(reached_cap_tensor.item())
+        if stop_after_step is None and reached_cap:
+            stop_after_step = step
+
+    log0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB")
+
+    if args.ema_enabled and ema_state is not None:
+        log0("Applying EMA weights for final evaluation...")
+        current_state = base_model.state_dict()
+        avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()}
+        base_model.load_state_dict(avg_state, strict=True)
+
+    # -----------------------------
+    # SERIALIZATION + SIZE REPORT + ROUNDTRIP VALIDATION
+    # -----------------------------
+    if master_process:
+        serialize_model(base_model, outdir, code, log0)
+    if distributed:
+        dist.barrier()
+
+    # Roundtrip: reload quantized weights, run eval to measure degradation.
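+    # Reloading the serialized int8+zlib artifact from disk (rather than reusing the in-memory
+    # weights) means the reported val_loss/val_bpb reflect exactly the bytes counted against the 16MB limit.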
+    quant_path_rt = os.path.join(outdir, "final_model.int8.ptz") if master_process else os.path.join(args.output_dir or "logs", "final_model.int8.ptz")
+    with open(quant_path_rt, "rb") as f:
+        quant_blob_disk = f.read()
+    quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu")
+    base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True)
+    torch.cuda.synchronize()
+    t_qeval = time.perf_counter()
+    q_val_loss, q_val_bpb = eval_val(
+        args, model, rank, world_size, device, grad_accum_steps,
+        val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+    )
+    torch.cuda.synchronize()
+    log0(
+        f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
+        f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
+    )
+    log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")
+
+    if args.ppm_enabled and rank == 0:
+        log0("Starting PPM byte mixture evaluation...")
+        ppm_mix_bpb = eval_val_with_ppm_mixture(
+            args=args,
+            model=base_model,
+            rank=rank,
+            world_size=world_size,
+            device=device,
+            val_tokens=val_tokens,
+            token_piece_bytes=token_piece_bytes,
+            has_leading_space=sp_has_leading_space,
+            is_boundary_token=sp_is_boundary_token,
+            log0=log0,
+        )
+        log0(f"ppm_mix_bpb:{ppm_mix_bpb:.6f}")
+
+    if args.ttt_enabled and args.ttt_mode == "lora" and lora_mgr is not None:
+        log0("Starting legal LoRA-TTT evaluation...")
+        ttt_val_loss, ttt_val_bpb = eval_val_sliding_lora_ttt(
+            args=args,
+            base_model=base_model,
+            lora_mgr=lora_mgr,
+            rank=rank,
+            world_size=world_size,
+            device=device,
+            val_tokens=val_tokens,
+            base_bytes_lut=base_bytes_lut,
+            has_leading_space_lut=has_leading_space_lut,
+            is_boundary_token_lut=is_boundary_token_lut,
+            log0=log0,
+        )
+        log0(f"lora_ttt_final val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f}")
+        log0(f"lora_ttt_final_exact val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f}")
+
+    if distributed:
+        dist.destroy_process_group()
+
+
+def eval_only_main() -> None:
+    """
+    Eval-only mode: load a checkpoint, report int8+zlib model size, and run
+    the standard val evaluation — no training.
+
+    Usage:
+        EVAL_ONLY=1 CHECKPOINT=path/to/final_model.pt OUTPUT_DIR=logs/eval \
+            python mytrain_gpt_v6.py
+
+    The checkpoint must be a raw state_dict saved with torch.save().
+    All model-architecture env vars must match the original training run.
+ """ + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + + rank = int(os.environ.get("RANK", "0")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + outdir = args.output_dir or "logs/eval" + os.makedirs(outdir, exist_ok=True) + logfile = os.path.join(outdir, f"eval_{time.strftime('%Y%m%d_%H%M%S')}.txt") + + def log0(msg: str, console: bool = True): + if not master_process: + return + if console: + print(msg) + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + checkpoint_path = os.environ.get("CHECKPOINT", "") + if not checkpoint_path: + raise ValueError("EVAL_ONLY=1 requires CHECKPOINT= to be set") + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + log0(f"eval_only checkpoint:{checkpoint_path}") + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"VOCAB_SIZE={args.vocab_size} != tokenizer vocab_size={int(sp.vocab_size())}") + + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + token_piece_bytes, sp_has_leading_space, sp_is_boundary_token = build_sentencepiece_byte_tables( + sp, args.vocab_size + ) + + base_model = GPT(args, master_process=master_process).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + + state = torch.load(checkpoint_path, map_location="cpu") + # Support both raw state_dict and {"model": state_dict} wrapping. 
+ if "model" in state and isinstance(state["model"], dict): + state = state["model"] + base_model.load_state_dict(state, strict=True) + log0(f"loaded checkpoint: {checkpoint_path}") + + total_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{total_params:,}") + + # ---------- size report ---------- + quant_file_bytes, code_bytes, stats = compute_submission_size(base_model.state_dict(), code) + limit_bytes = 16 * 1024 * 1024 + total_bytes = quant_file_bytes + code_bytes + log0( + f"submission_size int8+zlib:{quant_file_bytes} bytes " + f"code:{code_bytes} bytes total:{total_bytes} bytes " + f"limit:16MB({limit_bytes}) " + f"{'FITS' if total_bytes <= limit_bytes else 'EXCEEDS_LIMIT'}" + ) + log0(f" params:{stats['param_count']:,} " + f"float_tensors:{stats['num_float_tensors']} " + f"baseline_bytes:{stats['baseline_tensor_bytes']} " + f"int8_payload:{stats['int8_payload_bytes']} " + f"compression:{stats['baseline_tensor_bytes']/max(stats['int8_payload_bytes'],1):.2f}x") + + # ---------- val eval (fp weights) ---------- + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", "1")) + t0 = time.perf_counter() + val_loss, val_bpb = eval_val( + args, compiled_model, rank, 1, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"eval_fp val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-t0):.0f}ms" + ) + log0(f"eval_fp_exact val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f}") + + # ---------- int8 roundtrip eval ---------- + quant_obj, _ = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + quant_state = torch.load(io.BytesIO(zlib.decompress(blob)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t0 = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_model, rank, 1, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"eval_int8_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-t0):.0f}ms" + ) + log0(f"eval_int8_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if args.ppm_enabled: + ppm_bpb = eval_val_with_ppm_mixture( + args=args, model=base_model, rank=rank, world_size=1, + device=device, val_tokens=val_tokens, + token_piece_bytes=token_piece_bytes, + has_leading_space=sp_has_leading_space, + is_boundary_token=sp_is_boundary_token, + log0=log0, + ) + log0(f"ppm_mix_bpb:{ppm_bpb:.6f}") + + log0(f"eval_only done logfile:{logfile}") + + +if __name__ == "__main__": + if os.environ.get("EVAL_ONLY", "0") == "1": + eval_only_main() + else: + main() + +""" +python launchv3.py sweep_ppm_mixture_v1.json \ + --train-script mytrain_gpt_v6.py \ + --output output/run_sweep_ppm_mixture_v1 \ + --stop-mode steps \ + --max-steps 3000 \ + --no-analysis +"""