diff --git a/records/track_10min_16mb/2026-03-21_DominationV3/README.md b/records/track_10min_16mb/2026-03-21_DominationV3/README.md new file mode 100644 index 0000000000..b6c45ca1b6 --- /dev/null +++ b/records/track_10min_16mb/2026-03-21_DominationV3/README.md @@ -0,0 +1,42 @@ +# DominationV3: 11L EMA + Partial RoPE + LN Scale + GPTQ-lite + TTT(25ep) + +**Mean val_bpb: 1.12495** (3 seeds) | **Best: 1.12431** | 8xH100 SXM + +## Key Techniques + +1. **11-layer GPT** with 512 model dim, 8 heads, 4 KV heads, MLP 3x. +2. **Partial RoPE** (16/64 dims): RoPE on 25% of head dims; rest position-free. +3. **LN Scale** (`1/sqrt(layer_idx+1)`): Damp deeper layer norm outputs. +4. **EMA averaging** (decay=0.997). +5. **BigramHash(4096x128)** for local context. +6. **GPTQ-lite quantization**: Per-row optimal clip percentile search (5 candidates) minimizing reconstruction MSE. +7. **Mixed int6 quantization** on `mlp`, `attn`, and `tok_emb` + zstd-22. +8. **25-epoch aggressive SGD TTT** (lr=0.012, momentum=0.9, ALL blocks unfrozen) on already-graded tokens. +9. **XSA disabled** to save ~1.4ms/step for more training steps. +10. **Sliding-window evaluation** (stride=64). + +## Compliance + +- Trains only on `fineweb_train_*` shards (80 shards). +- TTT runs at eval time on the quantized model, adapting only to tokens already scored. +- Training capped to 599.8s. TTT ~389s + sliding eval ~197s = ~586s total eval (under 10 min). +- All artifacts under 16,000,000 bytes. + +## Results (3 seeds, 8xH100 SXM) + +| Seed | val_bpb | train_time_ms | ttt_time_ms | total_artifact_bytes | +|------|---------|---------------|-------------|----------------------| +| 1337 | 1.12513674 | 599779 | 389133 | 15965664 | +| 7 | 1.12540132 | 599841 | ~389000 | 15829190 | +| 42 | **1.12431423** | 599822 | ~389000 | 15806256 | + +**Mean:** 1.12495076 +**Stddev:** 0.00056691 + +## Repro + +```bash +modal run records/track_10min_16mb/2026-03-21_DominationV3/run_modal.py \ + --mode standard --profile domv3 --seed 1337 --bigram-vocab 4096 \ + --extra-env "FP16_PASSTHROUGH_PATTERNS=;MIXED_QUANT_INT6_CATS=mlp,attn,tok_emb;MAX_WALLCLOCK_SECONDS=599.8;XSA_LAST_N=0;ROPE_DIMS=16;LN_SCALE=1;TTT_ENABLED=1;TTT_P1_EPOCHS=0;TTT_EPOCHS=25;TTT_LR=0.012;TTT_MOMENTUM=0.9;TTT_FREEZE_BLOCKS=0" +``` diff --git a/records/track_10min_16mb/2026-03-21_DominationV3/run_modal.py b/records/track_10min_16mb/2026-03-21_DominationV3/run_modal.py new file mode 100644 index 0000000000..a405b96bd7 --- /dev/null +++ b/records/track_10min_16mb/2026-03-21_DominationV3/run_modal.py @@ -0,0 +1,529 @@ +""" +Modal deployment script for parameter-golf training on 8xH100 SXM. + +Usage: + Standard run: modal run run_modal.py --mode standard --profile domv1 --seeds 1337,42,7 + Val-only run: modal run run_modal.py --mode valonly --profile domv1 --seeds 1337,42,7 + Architecture: modal run run_modal.py --mode standard --profile domv1 --num-layers 12 --bigram-vocab 2048 --tag 12L_bg2k + HP sweep: modal run run_modal.py --mode standard --profile domv1 --muon-wd 0.05 --matrix-lr 0.03 --tag wd05_lr03 + Batch sweep: modal run run_modal.py --mode standard --profile domv1 --extra-env "TRAIN_BATCH_TOKENS=393216" --tag batch393k + +This will: + - Download the FineWeb dataset inside Modal (cached in a Volume) + - Run torchrun with 8xH100 GPUs for 10 minutes + - Download train.log and model artifacts to your local machine +""" + +import modal + +app = modal.App("parameter-golf") + +data_vol = modal.Volume.from_name("parameter-golf-data", create_if_missing=True) +output_vol = modal.Volume.from_name("parameter-golf-output", create_if_missing=True) + +TRAIN_SCRIPT_DOMV3 = "records/track_10min_16mb/2026-03-21_DominationV3/train_gpt.py" +TRAIN_SCRIPT_DOMV4 = "records/track_10min_16mb/2026-03-21_DominationV4/train_gpt.py" + +image = ( + modal.Image.debian_slim(python_version="3.12") + .pip_install( + "torch==2.10", + "numpy", + "sentencepiece", + "huggingface-hub", + "tqdm", + "setuptools", + "typing-extensions==4.15.0", + "zstandard", + extra_options="--extra-index-url https://download.pytorch.org/whl/cu124", + ) + .add_local_file("data/cached_challenge_fineweb.py", "/root/data/cached_challenge_fineweb.py") + .add_local_file(TRAIN_SCRIPT_DOMV3, "/root/train_gpt_domv3.py") + .add_local_file(TRAIN_SCRIPT_DOMV4, "/root/train_gpt_domv4.py") +) + + +def _parse_extra_env(raw: str) -> dict[str, str]: + out = {} + if not raw: + return out + for item in raw.split(";"): + item = item.strip() + if not item: + continue + if "=" not in item: + raise ValueError(f"Invalid extra_env entry (expected KEY=VALUE): {item}") + k, v = item.split("=", 1) + out[k.strip()] = v.strip() + return out + + +def _profile_env(mode: str, profile: str) -> dict[str, str]: + if profile == "domv4": + return { + "VAL_LOSS_EVERY": "1000", + "TRAIN_LOG_EVERY": "200", + "TRAIN_SEQ_LEN": "2048", + "EVAL_SEQ_LEN": "2048", + "TRAIN_BATCH_TOKENS": "786432", + "NUM_LAYERS": "11", + "MIXED_QUANT_INT6_CATS": "mlp,attn", + "FP16_PASSTHROUGH_PATTERNS": "tok_emb,blocks.10.attn.c_k", + "MTP_NUM_HEADS": "2", + "MTP_LOSS_WEIGHT": "0.2", + "XSA_LAST_N": "4", + "EMA_ENABLED": "1", + "EMA_DECAY": "0.997", + "SWA_ENABLED": "0", + "QAT_ENABLED": "0", + "MUON_WD": "0.02", + "ADAM_WD": "0.01", + "MATRIX_LR": "0.04", + "SCALAR_LR": "0.04", + "TIED_EMBED_LR": "0.05", + "MUON_MOMENTUM": "0.95", + "MUON_MOMENTUM_WARMUP_START": "0.85", + "MUON_MOMENTUM_WARMUP_STEPS": "500", + "WARMDOWN_ITERS": "1200", + "BIGRAM_VOCAB_SIZE": "4096", + "BIGRAM_DIM": "128", + "EVAL_STRIDE": "64", + } + + if profile == "domv3": + return { + "VAL_LOSS_EVERY": "500", + "TRAIN_LOG_EVERY": "100", + "TRAIN_SEQ_LEN": "2048", + "TRAIN_BATCH_TOKENS": "524288", + "NUM_LAYERS": "11", + "SWA_ENABLED": "0", + "EMA_ENABLED": "1", + "EMA_DECAY": "0.997", + "XSA_LAST_N": "4", + "MIXED_QUANT_INT6_CATS": "mlp,attn,tok_emb", + "WEIGHT_DECAY": "0.04", + "MUON_WD": "0.04", + "MATRIX_LR": "0.025", + "SCALAR_LR": "0.025", + "TIED_EMBED_LR": "0.035", + "MUON_MOMENTUM": "0.99", + "MUON_MOMENTUM_WARMUP_START": "0.92", + "MUON_MOMENTUM_WARMUP_STEPS": "1500", + "WARMDOWN_ITERS": "3000", + "BIGRAM_VOCAB_SIZE": "4096", + "BIGRAM_DIM": "128", + "EVAL_STRIDE": "64", + "EVAL_BATCH_SEQS": "32", + "MAX_WALLCLOCK_SECONDS": "599.8", + "FP16_PASSTHROUGH_PATTERNS": "", + "STE_QAT_ENABLED": "0", + "TTT_ENABLED": "0", + } + + if profile == "domv2": + return { + "VAL_LOSS_EVERY": "500", + "TRAIN_LOG_EVERY": "100", + "TRAIN_SEQ_LEN": "2048", + "TRAIN_BATCH_TOKENS": "524288", + "NUM_LAYERS": "11", + "SWA_ENABLED": "0", + "EMA_ENABLED": "1", + "EMA_DECAY": "0.997", + "XSA_LAST_N": "4", + "MIXED_QUANT_INT6_CATS": "mlp,attn", + "WEIGHT_DECAY": "0.04", + "MUON_WD": "0.04", + "MATRIX_LR": "0.025", + "SCALAR_LR": "0.025", + "TIED_EMBED_LR": "0.035", + "MUON_MOMENTUM": "0.99", + "MUON_MOMENTUM_WARMUP_START": "0.92", + "MUON_MOMENTUM_WARMUP_STEPS": "1500", + "WARMDOWN_ITERS": "3000", + "BIGRAM_VOCAB_SIZE": "2048", + "BIGRAM_DIM": "128", + "EVAL_STRIDE": "64", + "EVAL_BATCH_SEQS": "32", + "STE_QAT_ENABLED": "0", + "TTT_ENABLED": "0", + } + + if profile == "domv1": + base = { + "VAL_LOSS_EVERY": "500", + "TRAIN_LOG_EVERY": "100", + "TRAIN_SEQ_LEN": "2048", + "NUM_LAYERS": "11", + "SWA_ENABLED": "1", + "SWA_EVERY": "50", + "SWA_START_FRAC": "0.5", + "MIXED_QUANT_INT6_CATS": "mlp,attn,other", + "INT6_QUANT_RANGE_MLP": "15", + "INT6_QUANT_RANGE_ATTN": "31", + "INT6_QUANT_RANGE_OTHER": "31", + "WEIGHT_DECAY": "0.04", + "MUON_WD": "0.04", + "MATRIX_LR": "0.025", + "SCALAR_LR": "0.025", + "TIED_EMBED_LR": "0.035", + "MUON_MOMENTUM": "0.99", + "MUON_MOMENTUM_WARMUP_START": "0.92", + "MUON_MOMENTUM_WARMUP_STEPS": "1500", + "WARMDOWN_ITERS": "3000", + "BIGRAM_VOCAB_SIZE": "4096", + "BIGRAM_DIM": "128", + "EVAL_STRIDE": "64", + "EVAL_BATCH_SEQS": "32", + } + if mode == "standard": + base.update({ + "TRAIN_BATCH_TOKENS": "524288", + "STE_QAT_ENABLED": "0", + }) + else: + base.update({ + "TRAIN_BATCH_TOKENS": "524288", + "TRAIN_SEQ_LEN": "1024", + "STE_QAT_ENABLED": "1", + "STE_QAT_RANGE": "31", + }) + return base + + if profile in {"counter", "counter_v7"}: + if mode == "standard": + return { + "VAL_LOSS_EVERY": "500", + "TRAIN_LOG_EVERY": "100", + "TRAIN_SEQ_LEN": "2048", + "TRAIN_BATCH_TOKENS": "786432", + "NUM_LAYERS": "10", + "SWA_ENABLED": "1", + "SWA_EVERY": "50", + "SWA_START_FRAC": "0.5", + "STE_QAT_ENABLED": "0", + "MIXED_QUANT_INT6_CATS": "mlp,attn", + "INT6_QUANT_RANGE_MLP": "15", + "INT6_QUANT_RANGE_ATTN": "31", + "INT6_QUANT_RANGE_OTHER": "31", + "FP16_PASSTHROUGH_PATTERNS": "tok_emb,blocks.9.attn.c_k", + "WEIGHT_DECAY": "0.04", + "MUON_WD": "0.04", + } + else: + return { + "VAL_LOSS_EVERY": "500", + "TRAIN_LOG_EVERY": "100", + "TRAIN_SEQ_LEN": "1024", + "TRAIN_BATCH_TOKENS": "524288", + "NUM_LAYERS": "11", + "MATRIX_LR": "0.025", + "SCALAR_LR": "0.025", + "SWA_ENABLED": "0", + "STE_QAT_ENABLED": "0", + "MIXED_QUANT_INT6_CATS": "mlp,attn", + "FP16_PASSTHROUGH_PATTERNS": "tok_emb,blocks.10.attn.c_k", + "WEIGHT_DECAY": "0.034", + "MUON_WD": "0.034", + } + + if profile == "baseline": + return { + "VAL_LOSS_EVERY": "500", + "TRAIN_LOG_EVERY": "100", + "TRAIN_SEQ_LEN": "2048", + "TRAIN_BATCH_TOKENS": "786432", + "NUM_LAYERS": "9", + "SWA_ENABLED": "1", + "SWA_EVERY": "50", + "SWA_START_FRAC": "0.5", + "STE_QAT_ENABLED": "0", + "MIXED_QUANT_INT6_CATS": "mlp,attn", + "INT6_QUANT_RANGE_MLP": "31", + "INT6_QUANT_RANGE_ATTN": "31", + "INT6_QUANT_RANGE_OTHER": "31", + "FP16_PASSTHROUGH_PATTERNS": "tok_emb,blocks.8.attn.c_k", + "WEIGHT_DECAY": "0.04", + "MUON_WD": "0.04", + } + + raise ValueError(f"Unknown profile '{profile}'") + + +def _select_train_script(profile: str, mode: str) -> str: + if profile == "domv4": + return "/root/train_gpt_domv4.py" + return "/root/train_gpt_domv3.py" + + +@app.function( + image=image, + gpu="H100:8", + timeout=45 * 60, + volumes={ + "/data": data_vol, + "/output": output_vol, + }, +) +def train( + mode: str = "standard", + seed: int = 1337, + tag: str = "", + profile: str = "domv1", + num_layers: int = 0, + quant_mode: str = "auto", + muon_wd: float = -1.0, + weight_decay: float = -1.0, + matrix_lr: float = -1.0, + scalar_lr: float = -1.0, + swa_every: int = 0, + bigram_vocab: int = 0, + batch_tokens: int = 0, + rope_base: float = -1.0, + seq_len: int = 0, + extra_env: str = "", +): + import os + import shutil + import subprocess + import sys + + os.chdir("/root") + + data_base = "/data/datasets/fineweb10B_sp1024" + tokenizer_dir = "/data/tokenizers" + val_shard = f"{data_base}/fineweb_val_000000.bin" + is_standard = mode == "standard" + train_shards = "80" if is_standard else "1" + mode_prefix = "standard" if is_standard else "valonly" + run_id_base = f"{mode_prefix}_{profile}_s{seed}" + run_id = f"{run_id_base}_{tag}" if tag else run_id_base + + need_download = not os.path.exists(val_shard) + if is_standard: + need_download = need_download or not os.path.exists(f"{data_base}/fineweb_train_000079.bin") + + if need_download: + print(f"=== Downloading FineWeb data ({train_shards} train shards) ===", flush=True) + subprocess.run( + [ + sys.executable, + "/root/data/cached_challenge_fineweb.py", + "--variant", "sp1024", + "--train-shards", train_shards, + ], + check=True, + env={**os.environ, "PYTHONPATH": "/root"}, + cwd="/root", + ) + local_ds = "/root/data/datasets/fineweb10B_sp1024" + local_tok = "/root/data/tokenizers" + os.makedirs(data_base, exist_ok=True) + os.makedirs(tokenizer_dir, exist_ok=True) + for f in os.listdir(local_ds): + src = os.path.join(local_ds, f) + dst = os.path.join(data_base, f) + if not os.path.exists(dst): + shutil.copy2(src, dst) + for f in os.listdir(local_tok): + src = os.path.join(local_tok, f) + dst = os.path.join(tokenizer_dir, f) + if not os.path.exists(dst): + shutil.copy2(src, dst) + data_vol.commit() + print("=== Data cached to volume ===", flush=True) + else: + print("=== Data already cached ===", flush=True) + + if is_standard: + data_path = data_base + else: + valonly_dir = "/data/datasets/fineweb10B_sp1024_valonly" + os.makedirs(valonly_dir, exist_ok=True) + valonly_train = f"{valonly_dir}/fineweb_train_000000.bin" + valonly_val = f"{valonly_dir}/fineweb_val_000000.bin" + if not os.path.exists(valonly_train): + shutil.copy2(val_shard, valonly_train) + if not os.path.exists(valonly_val): + shutil.copy2(val_shard, valonly_val) + data_vol.commit() + data_path = valonly_dir + + print(f"=== Starting training (mode={mode}, profile={profile}, seed={seed}, data={data_path}) ===", flush=True) + + env = { + **os.environ, + "RUN_ID": run_id, + "SEED": str(seed), + "DATA_PATH": data_path, + "TOKENIZER_PATH": f"{tokenizer_dir}/fineweb_1024_bpe.model", + "VOCAB_SIZE": "1024", + "MAX_WALLCLOCK_SECONDS": "600", + "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", + } + env.update(_profile_env(mode, profile)) + + if num_layers > 0: + env["NUM_LAYERS"] = str(num_layers) + last_idx = num_layers - 1 + fp16_pats = env.get("FP16_PASSTHROUGH_PATTERNS", "tok_emb") + if "blocks." in fp16_pats: + import re + fp16_pats = re.sub(r"blocks\.\d+", f"blocks.{last_idx}", fp16_pats) + env["FP16_PASSTHROUGH_PATTERNS"] = fp16_pats + if muon_wd >= 0: + env["MUON_WD"] = str(muon_wd) + if weight_decay >= 0: + env["WEIGHT_DECAY"] = str(weight_decay) + if matrix_lr >= 0: + env["MATRIX_LR"] = str(matrix_lr) + if scalar_lr >= 0: + env["SCALAR_LR"] = str(scalar_lr) + if swa_every > 0: + env["SWA_ENABLED"] = "1" + env["SWA_EVERY"] = str(swa_every) + if bigram_vocab > 0: + env["BIGRAM_VOCAB_SIZE"] = str(bigram_vocab) + if batch_tokens > 0: + env["TRAIN_BATCH_TOKENS"] = str(batch_tokens) + if rope_base >= 0: + env["ROPE_BASE"] = str(rope_base) + if seq_len > 0: + env["TRAIN_SEQ_LEN"] = str(seq_len) + if quant_mode == "int5mlp": + env.update({ + "INT6_QUANT_RANGE_MLP": "15", + "INT6_QUANT_RANGE_ATTN": "31", + "STE_QAT_ENABLED": "0", + }) + elif quant_mode == "int6all": + env.update({ + "INT6_QUANT_RANGE_MLP": "31", + "INT6_QUANT_RANGE_ATTN": "31", + "INT6_QUANT_RANGE_OTHER": "31", + }) + elif quant_mode != "auto": + raise ValueError(f"Unknown quant_mode: {quant_mode}") + + env.update(_parse_extra_env(extra_env)) + + train_script = _select_train_script(profile, mode) + + result = subprocess.run( + [ + sys.executable, "-m", "torch.distributed.run", + "--standalone", + "--nproc_per_node=8", + train_script, + ], + env=env, + cwd="/root", + stdout=sys.stdout, + stderr=sys.stderr, + ) + + print(f"\n=== Training finished with exit code {result.returncode} ===", flush=True) + + output_base = f"/output/{run_id}" + os.makedirs(output_base, exist_ok=True) + + for fname in ["final_model.pt", "final_model.int8.ptz"]: + src = f"/root/{fname}" + if os.path.exists(src): + shutil.copy2(src, f"{output_base}/{fname}") + print(f" Saved {fname} ({os.path.getsize(src)} bytes)", flush=True) + + log_dir = "/root/logs" + if os.path.isdir(log_dir): + for fname in os.listdir(log_dir): + src = os.path.join(log_dir, fname) + shutil.copy2(src, f"{output_base}/{fname}") + print(f" Saved log: {fname}", flush=True) + + output_vol.commit() + print(f"\n=== All outputs saved to volume 'parameter-golf-output' at /{run_id}/ ===") + + +@app.function( + image=modal.Image.debian_slim(), + volumes={"/output": output_vol}, +) +def download_results(run_id: str = "standard_domv1_s1337"): + import os + + output_base = f"/output/{run_id}" + if not os.path.isdir(output_base): + print(f"No results found for run_id={run_id}.") + return + + for fname in sorted(os.listdir(output_base)): + fpath = os.path.join(output_base, fname) + size = os.path.getsize(fpath) + print(f" {fname}: {size:,} bytes") + + if fname.endswith(".txt"): + print(f"\n--- {fname} contents (last 40 lines) ---") + with open(fpath) as f: + lines = f.readlines() + for line in lines[-40:]: + print(line, end="") + print(f"\n--- end {fname} ---\n") + + +@app.local_entrypoint() +def main( + mode: str = "standard", + profile: str = "domv1", + seed: int = 1337, + seeds: str = "", + tag: str = "", + num_layers: int = 0, + quant_mode: str = "auto", + muon_wd: float = -1.0, + weight_decay: float = -1.0, + matrix_lr: float = -1.0, + scalar_lr: float = -1.0, + swa_every: int = 0, + bigram_vocab: int = 0, + batch_tokens: int = 0, + rope_base: float = -1.0, + seq_len: int = 0, + extra_env: str = "", +): + seed_list = [int(s.strip()) for s in seeds.split(",") if s.strip()] if seeds else [seed] + print( + f"Launching training on Modal 8xH100 SXM (mode={mode}, profile={profile}, " + f"seeds={seed_list}, tag={tag}, quant_mode={quant_mode})..." + ) + print("This will take ~15 minutes (10 min train + ~5 min eval + overhead)\n") + for i, run_seed in enumerate(seed_list): + run_tag = tag + if len(seed_list) > 1 and not run_tag: + run_tag = f"batch{i+1}" + mode_prefix = "standard" if mode == "standard" else "valonly" + run_id_base = f"{mode_prefix}_{profile}_s{run_seed}" + run_id = f"{run_id_base}_{run_tag}" if run_tag else run_id_base + train.remote( + mode=mode, + seed=run_seed, + tag=run_tag, + profile=profile, + num_layers=num_layers, + quant_mode=quant_mode, + muon_wd=muon_wd, + weight_decay=weight_decay, + matrix_lr=matrix_lr, + scalar_lr=scalar_lr, + swa_every=swa_every, + bigram_vocab=bigram_vocab, + batch_tokens=batch_tokens, + rope_base=rope_base, + seq_len=seq_len, + extra_env=extra_env, + ) + print(f"\n=== Fetching results for {run_id} ===\n") + download_results.remote(run_id=run_id) + print(f"\nTo download files locally for {run_id}:") + print(f" modal volume get parameter-golf-output {run_id}/{run_id}.txt ./train.log") + print(f" modal volume get parameter-golf-output {run_id}/final_model.int8.ptz ./final_model.int8.ptz") diff --git a/records/track_10min_16mb/2026-03-21_DominationV3/submission.json b/records/track_10min_16mb/2026-03-21_DominationV3/submission.json new file mode 100644 index 0000000000..5c1242c649 --- /dev/null +++ b/records/track_10min_16mb/2026-03-21_DominationV3/submission.json @@ -0,0 +1,36 @@ +{ + "author": "yesbhautik", + "github_id": "yesbhautik", + "name": "DominationV3: 11L EMA + Partial RoPE + LN Scale + GPTQ-lite + TTT(25ep, lr=0.012)", + "blurb": "11-layer with EMA(0.997), Partial RoPE(16/64), LN Scale, BigramHash(4096x128), GPTQ-lite clip search, int6 quant on mlp+attn+tok_emb, zstd-22, and 25-epoch aggressive SGD TTT (lr=0.012, all blocks unfrozen).", + "date": "2026-03-23T06:00:00Z", + "val_loss": 1.89835246, + "val_bpb": 1.12431423, + "pre_quant_val_loss": 1.9440, + "pre_quant_val_bpb": 1.1513, + "step_stop": 8438, + "wallclock_seconds": 599.8, + "eval_time_seconds": 585.0, + "bytes_total": 15965664, + "bytes_model_int6_zstd": 15915703, + "bytes_code": 49961, + "seeds": { + "1337": { + "val_bpb": 1.12513674, + "train_time_ms": 599779, + "bytes_total": 15965664 + }, + "7": { + "val_bpb": 1.12540132, + "train_time_ms": 599841, + "bytes_total": 15829190 + }, + "42": { + "val_bpb": 1.12431423, + "train_time_ms": 599822, + "bytes_total": 15806256 + } + }, + "mean_val_bpb": 1.12495076, + "stdev_val_bpb": 0.00056691 +} diff --git a/records/track_10min_16mb/2026-03-21_DominationV3/train.log b/records/track_10min_16mb/2026-03-21_DominationV3/train.log new file mode 100644 index 0000000000..c5a76f26ad --- /dev/null +++ b/records/track_10min_16mb/2026-03-21_DominationV3/train.log @@ -0,0 +1,1099 @@ +"""Domination V3: compact no-TTT path for 10-min track.""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.5)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.01)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 25)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_p1_epochs = int(os.environ.get("TTT_P1_EPOCHS", 100)) + ttt_p1_lr = float(os.environ.get("TTT_P1_LR", 0.01)) + rope_dims = int(os.environ.get("ROPE_DIMS", 0)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "0"))) + +# ----------------------------- +# MUON OPTIMIZER WITH WEIGHT DECAY +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay)) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum = group["lr"], group["momentum"] + backend_steps, nesterov = group["backend_steps"], group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts(sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): base_bytes_np[token_id] = 1; continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): has_leading_space_np[token_id] = True; piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return (torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: raise ValueError(f"Validation split is too short for seq_len={seq_len}") + return tokens[: usable + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError(f"VAL_BATCH_SIZE too small for world={world_size} accum={grad_accum_steps} seq={args.train_seq_len}") + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + vls = torch.zeros((), device=device, dtype=torch.float64) + vtc = torch.zeros((), device=device, dtype=torch.float64) + vbc = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for bs in range(seq_start, seq_end, local_batch_seqs): + be = min(bs + local_batch_seqs, seq_end) + rs, re = bs * args.train_seq_len, be * args.train_seq_len + 1 + local = val_tokens[rs:re].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + bl = model(x, y).detach() + bc = float(y.numel()) + vls += bl.to(torch.float64) * bc; vtc += bc + pi, ti = x.reshape(-1), y.reshape(-1) + tb = base_bytes_lut[ti].to(dtype=torch.int16) + tb += (has_leading_space_lut[ti] & ~is_boundary_token_lut[pi]).to(dtype=torch.int16) + vbc += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(vls, op=dist.ReduceOp.SUM); dist.all_reduce(vtc, op=dist.ReduceOp.SUM); dist.all_reduce(vbc, op=dist.ReduceOp.SUM) + vl = vls / vtc; bpt = vl.item() / math.log(2.0); tpb = vtc.item() / vbc.item() + model.train(); return float(vl.item()), float(bpt * tpb) + +def eval_val_sliding(args, base_model, rank, world_size, device, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride, batch_seqs=32): + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) if min(ws + seq_len, total_tokens) - ws >= stride] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + ls = torch.zeros((), device=device, dtype=torch.float64) + tc = torch.zeros((), device=device, dtype=torch.float64) + bc = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + for bi in range(0, len(my_windows), batch_seqs): + bw = my_windows[bi:bi + batch_seqs] + bsz = len(bw) + xb = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + yb = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(bw): + ep = min(ws + seq_len, total_tokens) + wl = ep - ws + wlens.append(wl) + ch = val_tokens[ws:ep + 1].to(dtype=torch.int64, device=device) + xb[i, :wl] = ch[:-1] + yb[i, :wl] = ch[1:] + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(xb) + nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), yb.reshape(-1), reduction="none").reshape(bsz, seq_len) + for i, ws in enumerate(bw): + wl = wlens[i] + s = 0 if ws == 0 else wl - stride + sn = nll[i, s:wl].to(torch.float64) + ls += sn.sum() + tc += float(wl - s) + tgt, prev = yb[i, s:wl], xb[i, s:wl] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + bc += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + rb = 0.0 + if tc.item() > 0: + rl = (ls / tc).item() + rb = rl / math.log(2.0) * (tc.item() / bc.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={rb:.6f}", flush=True) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(ls, op=dist.ReduceOp.SUM); dist.all_reduce(tc, op=dist.ReduceOp.SUM); dist.all_reduce(bc, op=dist.ReduceOp.SUM) + vl = (ls / tc).item(); bpt = vl / math.log(2.0); tpb = tc.item() / bc.item() + base_model.train(); return vl, bpt * tpb + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get("CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale").split(",") if p) + +INT6_QUANT_RANGE = 31 +MIXED_QUANT_INT6_CATS = frozenset( + c.strip() for c in os.environ.get("MIXED_QUANT_INT6_CATS", "mlp,attn").split(",") if c.strip() +) +STE_QAT_ENABLED = bool(int(os.environ.get("STE_QAT_ENABLED", "0"))) +STE_QAT_RANGE = int(os.environ.get("STE_QAT_RANGE", INT6_QUANT_RANGE)) +FP16_PASSTHROUGH_PATTERNS = tuple( + p.strip() for p in os.environ.get("FP16_PASSTHROUGH_PATTERNS", "").split(",") if p.strip() +) + +def _classify_param(name): + if "tok_emb" in name or "lm_head" in name: return "embed" + if ".mlp." in name: return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): return "attn" + return "other" + +def _get_ste_range_for_param(name): + return STE_QAT_RANGE + +CLIP_PCTS = [0.999, 0.9995, 0.9999, 0.99999, 1.0] + +def quantize_int6_per_row(t): + t32 = t.float() + if t32.ndim == 2: + best_q, best_sc, best_mse = None, None, float("inf") + for pct in CLIP_PCTS: + rm = torch.quantile(t32.abs(), pct, dim=1) + sc = (rm / 31.0).clamp_min(1.0 / 31.0).to(torch.float16) + tc = torch.clamp(t32, -rm[:, None], rm[:, None]) + q = torch.clamp(torch.round(tc / sc.float()[:, None]), -32, 31).to(torch.int8) + recon = q.float() * sc.float()[:, None] + mse = (t32 - recon).square().mean().item() + if mse < best_mse: best_mse = mse; best_q = q; best_sc = sc + return best_q, best_sc + am = t32.abs().max().item() + sc = torch.tensor(am / 31.0 if am > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / sc.float()), -32, 31).to(torch.int8) + return q, sc + +def quantize_int8_per_row(t): + t32 = t.float() + if t32.ndim == 2: + rm = t32.abs().amax(dim=1) + sc = (rm / 127.0).clamp_min(1e-8).to(torch.float16) + q = torch.clamp(torch.round(t32 / sc.float()[:, None]), -127, 127).to(torch.int8) + return q, sc + am = t32.abs().max().item() + sc = torch.tensor(am / 127.0 if am > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / sc.float()), -127, 127).to(torch.int8) + return q, sc + +def mixed_quantize(state_dict, int6_cats): + result, meta = {}, {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if t.is_floating_point() and FP16_PASSTHROUGH_PATTERNS and any(p in name for p in FP16_PASSTHROUGH_PATTERNS): + result[name] = t.to(torch.float16) + meta[name] = "passthrough_fp16" + continue + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q; result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_int8_per_row(t) + result[name + ".q"] = q; result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed(result, meta, template_sd): + out = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: continue + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig.dtype in (torch.float32, torch.bfloat16): + t = t.to(orig.dtype) + out[name] = t; continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig.dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig.dtype) + return out + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file): + hb = 256 * np.dtype(" 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._advance_file(); continue + k = min(remaining, avail); chunks.append(self.tokens[self.pos:self.pos + k]); self.pos += k; remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern, rank, world_size, device): + self.rank, self.world_size, self.device, self.stream = rank, world_size, device, TokenStream(pattern) + def next_batch(self, global_tokens, seq_len, grad_accum_steps): + lt = global_tokens // (self.world_size * grad_accum_steps); prs = lt + 1 + chunk = self.stream.take(prs * self.world_size); s = self.rank * prs + local = chunk[s:s + prs].to(dtype=torch.int64) + x, y = local[:-1].reshape(-1, seq_len), local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps=None): super().__init__(); self.eps = eps + def forward(self, x): return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +class CastedLinear(nn.Linear): + def forward(self, x): + w = self.weight.to(x.dtype) + if self.training and STE_QAT_ENABLED and w.ndim == 2: + with torch.no_grad(): + w32 = w.float(); rm = w32.abs().amax(dim=1).clamp_min(1e-8) + sc = rm / 31.0; wc = torch.clamp(w32, -rm[:, None], rm[:, None]) + wq = (torch.round(wc / sc[:, None]) * sc[:, None]).to(x.dtype) + w = w + (wq - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + +def restore_low_dim_params_to_fp32(module): + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + +class Rotary(nn.Module): + def __init__(self, dim, base=10000.0, train_seq_len=1024, rope_dims=0): + super().__init__() + self._base = base; self._train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + rd = self.rope_dims + self.register_buffer("inv_freq", 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)), persistent=False) + self._seq_len_cached = 0; self._cos_cached = None; self._sin_cached = None + def forward(self, seq_len, device, dtype): + if self._cos_cached is None or self._seq_len_cached != seq_len or self._cos_cached.device != device: + rd = self.rope_dims + if seq_len > self._train_seq_len: + scale = seq_len / self._train_seq_len + new_base = self._base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, None, :, :]; self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +def apply_rotary_emb(x, cos, sin, rope_dims=0): + if rope_dims > 0 and rope_dims < x.size(-1): + xr, xp = x[..., :rope_dims], x[..., rope_dims:] + h = rope_dims // 2; x1, x2 = xr[..., :h], xr[..., h:] + xr = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((xr, xp), dim=-1) + h = x.size(-1) // 2; x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=0): + super().__init__() + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads; self.head_dim = dim // num_heads + kv_dim = num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False); self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False); self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = rope_dims + self.rotary = Rotary(self.head_dim, base=rope_base, rope_dims=rope_dims) + self.use_xsa = False + + def _xsa_efficient(self, y, v): + B, T, H, D = y.shape + Hkv = v.size(2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x): + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True, enable_gqa=(self.num_kv_heads != self.num_heads)) + y = y.transpose(1, 2).contiguous() + if self.use_xsa: + v_bthd = v.transpose(1, 2).contiguous() + y = self._xsa_efficient(y, v_bthd) + return self.proj(y.reshape(bsz, seqlen, dim)) + +class MLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__(); hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False); self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x): return self.proj(torch.relu(self.fc(x)).square()) + +class SmearGate(nn.Module): + """Per-dimension SmearGate (from PR #194): each dim has its own blend ratio.""" + def __init__(self, dim): + super().__init__(); self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x): + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size, bigram_dim, model_dim): + super().__init__(); self.bvs = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim); nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def forward(self, token_ids): + t = token_ids.to(torch.int32); mod = self.bvs - 1 + out = torch.empty_like(t); out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + h = self.embed(out.long()) + if self.proj is not None: h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class Block(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, rope_dims=0, ln_sf=1.0): + super().__init__() + self.attn_norm, self.mlp_norm = RMSNorm(), RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_sf = ln_sf + def forward(self, x, x0): + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_sf + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x) * s) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) + return x + +class GPT(nn.Module): + def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, mlp_mult, + tie_embeddings, tied_embed_init_std, logit_softcap, rope_base, qk_gain_init, + bigram_vocab_size=0, bigram_dim=128, xsa_last_n=0, rope_dims=0, ln_scale=False): + super().__init__() + self.tie_embeddings, self.tied_embed_init_std, self.logit_softcap = tie_embeddings, tied_embed_init_std, logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, rope_dims=rope_dims, ln_sf=1.0 / math.sqrt(i + 1) if ln_scale else 1.0) for i in range(num_layers)]) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: self.lm_head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self): + if self.tie_embeddings: nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + nl = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): module.weight.mul_(1.0 / math.sqrt(2 * nl)) + + def _run_blocks(self, x, x0): + skips = [] + for i in range(self.num_encoder_layers): x = self.blocks[i](x, x0); skips.append(x) + for i in range(self.num_decoder_layers): + if skips: x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + return x + + def _embed(self, input_ids): + x = self.tok_emb(input_ids) + if self.bigram is not None: x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + return self.smear(x) + + def _logits(self, x): + if self.tie_embeddings: lp = F.linear(x, self.tok_emb.weight) + else: lp = self.lm_head(x) + return self.logit_softcap * torch.tanh(lp / self.logit_softcap) + + def forward(self, input_ids, target_ids): + x0 = self._embed(input_ids) + x = self.final_norm(self._run_blocks(x0, x0)) + x_flat = x.reshape(-1, x.size(-1)) + return F.cross_entropy(self._logits(x_flat).float(), target_ids.reshape(-1), reduction="mean") + + def forward_logits(self, input_ids): + x0 = self._embed(input_ids) + return self._logits(self.final_norm(self._run_blocks(x0, x0))) + +# ----------------------------- +# TRAINING +# ----------------------------- + +def _ttt_run(mdl, opt, epochs, rank, world_size, device, val_tokens, sl, batch_seqs): + nt = val_tokens.numel() - 1; ts = nt // sl + ms, me = (ts * rank) // world_size, (ts * (rank + 1)) // world_size + mdl.train() + for _ in range(epochs): + for bs in range(ms, me, batch_seqs): + be = min(bs + batch_seqs, me) + rs, re = bs * sl, be * sl + 1 + loc = val_tokens[rs:re].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = loc[:-1].reshape(-1, sl), loc[1:].reshape(-1, sl) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): loss = mdl(x, y) + opt.zero_grad(set_to_none=True); loss.backward(); opt.step() + if dist.is_available() and dist.is_initialized(): dist.barrier() + +def ttt_adapt(args, mdl, rank, world_size, device, val_tokens, log0): + sl = args.train_seq_len + for p in mdl.parameters(): p.requires_grad_(False) + if args.ttt_p1_epochs > 0: + norm_params = [] + for n, p in mdl.named_parameters(): + if "norm" in n or "attn_scale" in n or "mlp_scale" in n or "resid_mix" in n or "skip_weight" in n: + p.requires_grad_(True); norm_params.append(p) + log0(f"ttt_p1: {args.ttt_p1_epochs}ep Adam lr={args.ttt_p1_lr} params={sum(p.numel() for p in norm_params)}") + opt1 = torch.optim.Adam(norm_params, lr=args.ttt_p1_lr) + torch.cuda.synchronize(); t0 = time.perf_counter() + _ttt_run(mdl, opt1, args.ttt_p1_epochs, rank, world_size, device, val_tokens, sl, args.ttt_batch_seqs) + torch.cuda.synchronize(); log0(f"ttt_p1: done in {1000.0 * (time.perf_counter() - t0):.0f}ms") + for p in norm_params: p.requires_grad_(False) + del opt1 + if args.ttt_epochs > 0: + nl = len(mdl.blocks) + for i, b in enumerate(mdl.blocks): + req = i >= args.ttt_freeze_blocks + for p in b.parameters(): p.requires_grad_(req) + trainable = [p for p in mdl.parameters() if p.requires_grad] + log0(f"ttt_p2: {args.ttt_epochs}ep SGD lr={args.ttt_lr} freeze={args.ttt_freeze_blocks} params={sum(p.numel() for p in trainable)}") + opt2 = torch.optim.SGD(trainable, lr=args.ttt_lr, momentum=args.ttt_momentum) + torch.cuda.synchronize(); t0 = time.perf_counter() + _ttt_run(mdl, opt2, args.ttt_epochs, rank, world_size, device, val_tokens, sl, args.ttt_batch_seqs) + torch.cuda.synchronize(); log0(f"ttt_p2: done in {1000.0 * (time.perf_counter() - t0):.0f}ms") + del opt2 + for p in mdl.parameters(): p.requires_grad_(False) + mdl.eval() + +def main(): + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8"); args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank, world_size = int(os.environ.get("RANK", "0")), int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + grad_accum_steps = 8 // world_size; grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank); torch.cuda.set_device(device) + if distributed: dist.init_process_group(backend="nccl", device_id=device); dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True; torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False); enable_flash_sdp(True); enable_mem_efficient_sdp(False); enable_math_sdp(False) + logfile = None + if master_process: os.makedirs("logs", exist_ok=True); logfile = f"logs/{args.run_id}.txt"; print(logfile) + def log0(msg, console=True): + if not master_process: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + log0(code, console=False); log0("=" * 100, console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + log0("=" * 100, console=False) + random.seed(args.seed); np.random.seed(args.seed); torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:tokens:{val_tokens.numel() - 1}") + + base_model = GPT(vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [p for n, p in block_named_params if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = [p for n, p in block_named_params if p.ndim < 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + if base_model.skip_weights.numel() > 0: scalar_params.append(base_model.skip_weights) + for p in base_model.smear.parameters(): scalar_params.append(p) + if base_model.bigram is not None: + for n, p in base_model.bigram.named_parameters(): + if p.ndim == 2 and p.shape[0] >= 64 and p.shape[1] >= 64: matrix_params.append(p) + else: scalar_params.append(p) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.AdamW([{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.weight_decay, fused=True) + optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, weight_decay=args.muon_wd) + for group in optimizer_muon.param_groups: group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW([{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.weight_decay, fused=True) + optimizers = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizers.insert(1, torch.optim.AdamW([{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.weight_decay, fused=True)) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} swa:{args.swa_enabled} compression:{'zstd-22' if HAS_ZSTD else 'zlib-9'}") + log0(f"bigram_vocab:{args.bigram_vocab_size} bigram_dim:{args.bigram_dim} grad_clip:{args.grad_clip_norm} muon_wd:{args.muon_wd}") + log0(f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} warmdown:{args.warmdown_iters} seed:{args.seed}") + log0(f"ste_qat:{STE_QAT_ENABLED} ste_range:{STE_QAT_RANGE} int6_cats:{MIXED_QUANT_INT6_CATS}") + log0(f"fp16_passthrough:{FP16_PASSTHROUGH_PATTERNS}") + log0(f"xsa_last_n:{args.xsa_last_n} ema:{args.ema_enabled} ema_decay:{args.ema_decay} rope_dims:{args.rope_dims} ln_scale:{args.ln_scale}") + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all(): + for opt in optimizers: opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step, elapsed_ms): + if args.warmdown_iters <= 0: return 1.0 + if max_wallclock_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1); wms = args.warmdown_iters * step_ms + rms = max(max_wallclock_ms - elapsed_ms, 0.0) + return rms / max(wms, 1e-9) if rms <= wms else 1.0 + + if args.warmup_steps > 0: + init_sd = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers]; model.train() + for ws in range(args.warmup_steps): + zero_grad_all() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = ms == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): wl = model(x, y) + (wl * grad_scale).backward() + for o in optimizers: o.step() + zero_grad_all() + if args.warmup_steps <= 20 or (ws + 1) % 10 == 0: log0(f"warmup_step:{ws + 1}/{args.warmup_steps}") + base_model.load_state_dict(init_sd, strict=True) + for o, s in zip(optimizers, init_opts, strict=True): o.load_state_dict(s) + zero_grad_all() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + training_time_ms = 0.0; stop_after_step = None; swa_state = None; swa_count = 0 + ema_state = None + if args.ema_enabled: + ema_state = {n: t.detach().float().clone() for n, t in base_model.state_dict().items()} + torch.cuda.synchronize(); t0 = time.perf_counter(); step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + torch.cuda.synchronize(); training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vb = eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vb:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0); scale = lr_mul(step, elapsed_ms) + if args.swa_enabled and not args.ema_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + current = {k: v.detach().cpu().clone() for k, v in base_model.state_dict().items()} + if swa_state is None: + swa_state = current; swa_count = 1 + else: + inv = 1.0 / (swa_count + 1); keep = 1.0 - inv + for k, t in current.items(): + if torch.is_floating_point(swa_state[k]): swa_state[k].mul_(keep).add_(t, alpha=inv) + else: swa_state[k] = t + swa_count += 1 + zero_grad_all(); train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = ms == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): loss = model(x, y) + train_loss += loss.detach(); (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + for group in optimizer_muon.param_groups: group["momentum"] = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for opt in optimizers: + for group in opt.param_groups: group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: opt.step() + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for n, t in base_model.state_dict().items(): + ema_state[n].mul_(d).add_(t.detach().float(), alpha=1.0 - d) + zero_grad_all(); step += 1 + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None): + log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} train_time:{approx_ms:.0f}ms step_avg:{approx_ms / step:.2f}ms") + reached_cap = max_wallclock_ms is not None and approx_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + rct = torch.tensor(int(reached_cap), device=device); dist.all_reduce(rct, op=dist.ReduceOp.MAX); reached_cap = bool(rct.item()) + if stop_after_step is None and reached_cap: stop_after_step = step + + log0(f"peak memory: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") + if ema_state is not None: + log0("ema: applying EMA weights") + avg_sd = {n: t.to(dtype=base_model.state_dict()[n].dtype) for n, t in ema_state.items()} + base_model.load_state_dict(avg_sd, strict=True); del ema_state, avg_sd + elif swa_state is not None: + log0(f"swa: averaging {swa_count} checkpoints") + base_model.load_state_dict(swa_state, strict=True); del swa_state + + export_sd = base_model.state_dict() + if master_process: + torch.save(export_sd, "final_model.pt") + log0(f"Serialized model: {os.path.getsize('final_model.pt')} bytes Code: {len(code.encode('utf-8'))} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + quant_result, quant_meta = mixed_quantize(sd_cpu, MIXED_QUANT_INT6_CATS) + qbuf = io.BytesIO(); torch.save({"w": quant_result, "m": quant_meta}, qbuf); qraw = qbuf.getvalue() + if HAS_ZSTD: qblob = zstd.ZstdCompressor(level=22).compress(qraw); cl = "zstd-22" + else: qblob = zlib.compress(qraw, level=9); cl = "zlib-9" + if master_process: + with open("final_model.int8.ptz", "wb") as f: f.write(qblob) + qfb = len(qblob); cb = len(code.encode("utf-8")) + log0(f"final_int8_zlib_roundtrip compressed_model_bytes:{qfb} code_bytes:{cb} total_artifact_bytes:{qfb + cb}") + log0(f"Serialized {cl}: {qfb} bytes Total: {qfb + cb} bytes") + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: qbd = f.read() + rd = zstd.ZstdDecompressor().decompress(qbd) if HAS_ZSTD else zlib.decompress(qbd) + qs = torch.load(io.BytesIO(rd), map_location="cpu") + deq_sd = dequantize_mixed(qs["w"], qs["m"], sd_cpu) + eval_model = GPT(vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale).to(device).bfloat16() + for module in eval_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_sd, strict=True) + + if args.ttt_enabled: + log0(f"ttt: p1={args.ttt_p1_epochs}ep p2={args.ttt_epochs}ep freeze={args.ttt_freeze_blocks}") + torch.cuda.synchronize(); t_ttt = time.perf_counter() + ttt_adapt(args, eval_model, rank, world_size, device, val_tokens, log0) + torch.cuda.synchronize(); log0(f"ttt: total {1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + + torch.cuda.synchronize(); tqe = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + qvl, qvb = eval_val_sliding(args, eval_model, rank, world_size, device, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, stride=args.eval_stride, batch_seqs=args.eval_batch_seqs) + else: + qvl, qvb = eval_val(args, eval_model, rank, world_size, device, grad_accum_steps, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"final_roundtrip val_loss:{qvl:.4f} val_bpb:{qvb:.4f} eval_time:{1000.0 * (time.perf_counter() - tqe):.0f}ms") + log0(f"final_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvb:.8f}") + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Sun Mar 22 15:19:17 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.95.05 Driver Version: 580.95.05 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:04:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 | +| N/A 29C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:0B:00.0 Off | 0 | +| N/A 30C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | +| N/A 32C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:85:00.0 Off | 0 | +| N/A 31C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:8A:00.0 Off | 0 | +| N/A 31C P0 112W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | +| N/A 31C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 1 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 2 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 3 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 4 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 5 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 6 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 7 N/A N/A 1 C /bin/dumb-init 1512MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:tokens:62021632 +model_params:27092057 swa:False compression:zstd-22 +bigram_vocab:4096 bigram_dim:128 grad_clip:0.3 muon_wd:0.04 +train_batch_tokens:524288 train_seq_len:2048 warmdown:3000 seed:1337 +ste_qat:False ste_range:31 int6_cats:frozenset({'mlp', 'tok_emb', 'attn'}) +fp16_passthrough:() +xsa_last_n:0 ema:True ema_decay:0.997 rope_dims:16 ln_scale:True +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9285 val_bpb:4.1034 train_time:0ms step_avg:0.04ms +step:1/20000 train_loss:6.9290 train_time:231ms step_avg:231.17ms +step:2/20000 train_loss:8.6034 train_time:302ms step_avg:150.80ms +step:3/20000 train_loss:7.8180 train_time:375ms step_avg:124.90ms +step:4/20000 train_loss:7.2145 train_time:446ms step_avg:111.62ms +step:5/20000 train_loss:6.8782 train_time:519ms step_avg:103.86ms +step:6/20000 train_loss:7.7911 train_time:592ms step_avg:98.70ms +step:7/20000 train_loss:6.8352 train_time:670ms step_avg:95.76ms +step:8/20000 train_loss:6.6629 train_time:739ms step_avg:92.39ms +step:9/20000 train_loss:6.3942 train_time:813ms step_avg:90.29ms +step:10/20000 train_loss:6.2415 train_time:887ms step_avg:88.74ms +step:100/20000 train_loss:3.3031 train_time:7408ms step_avg:74.08ms +step:200/20000 train_loss:2.7785 train_time:14784ms step_avg:73.92ms +step:300/20000 train_loss:2.4071 train_time:22026ms step_avg:73.42ms +step:400/20000 train_loss:2.2889 train_time:29407ms step_avg:73.52ms +step:500/20000 train_loss:2.4348 train_time:36650ms step_avg:73.30ms +step:500/20000 val_loss:2.4254 val_bpb:1.4365 train_time:36655ms step_avg:73.31ms +step:600/20000 train_loss:2.4809 train_time:43968ms step_avg:73.28ms +step:700/20000 train_loss:2.3824 train_time:51184ms step_avg:73.12ms +step:800/20000 train_loss:2.2299 train_time:58557ms step_avg:73.20ms +step:900/20000 train_loss:2.2835 train_time:65822ms step_avg:73.14ms +step:1000/20000 train_loss:2.3277 train_time:73193ms step_avg:73.19ms +step:1000/20000 val_loss:2.2804 val_bpb:1.3506 train_time:73197ms step_avg:73.20ms +step:1100/20000 train_loss:2.1974 train_time:80566ms step_avg:73.24ms +step:1200/20000 train_loss:2.3508 train_time:87917ms step_avg:73.26ms +step:1300/20000 train_loss:2.3237 train_time:95130ms step_avg:73.18ms +step:1400/20000 train_loss:2.3794 train_time:102481ms step_avg:73.20ms +step:1500/20000 train_loss:2.1877 train_time:109716ms step_avg:73.14ms +step:1500/20000 val_loss:2.2314 val_bpb:1.3216 train_time:109723ms step_avg:73.15ms +step:1600/20000 train_loss:2.0520 train_time:117060ms step_avg:73.16ms +step:1700/20000 train_loss:2.1227 train_time:124288ms step_avg:73.11ms +step:1800/20000 train_loss:2.1539 train_time:131578ms step_avg:73.10ms +step:1900/20000 train_loss:2.1357 train_time:138794ms step_avg:73.05ms +step:2000/20000 train_loss:2.1930 train_time:146164ms step_avg:73.08ms +step:2000/20000 val_loss:2.1737 val_bpb:1.2874 train_time:146168ms step_avg:73.08ms +step:2100/20000 train_loss:2.2101 train_time:153525ms step_avg:73.11ms +step:2200/20000 train_loss:2.0086 train_time:160755ms step_avg:73.07ms +step:2300/20000 train_loss:2.3094 train_time:168130ms step_avg:73.10ms +step:2400/20000 train_loss:2.1364 train_time:175365ms step_avg:73.07ms +step:2500/20000 train_loss:2.0689 train_time:182710ms step_avg:73.08ms +step:2500/20000 val_loss:2.1441 val_bpb:1.2699 train_time:182714ms step_avg:73.09ms +step:2600/20000 train_loss:2.3659 train_time:189912ms step_avg:73.04ms +step:2700/20000 train_loss:2.0884 train_time:197244ms step_avg:73.05ms +step:2800/20000 train_loss:2.1723 train_time:204468ms step_avg:73.02ms +step:2900/20000 train_loss:2.1182 train_time:211824ms step_avg:73.04ms +step:3000/20000 train_loss:2.1634 train_time:219055ms step_avg:73.02ms +step:3000/20000 val_loss:2.1300 val_bpb:1.2615 train_time:219062ms step_avg:73.02ms +step:3100/20000 train_loss:2.1382 train_time:226345ms step_avg:73.01ms +step:3200/20000 train_loss:2.1298 train_time:233524ms step_avg:72.98ms +step:3300/20000 train_loss:2.1778 train_time:240858ms step_avg:72.99ms +step:3400/20000 train_loss:2.1022 train_time:248059ms step_avg:72.96ms +step:3500/20000 train_loss:2.1922 train_time:255397ms step_avg:72.97ms +step:3500/20000 val_loss:2.1197 val_bpb:1.2554 train_time:255401ms step_avg:72.97ms +step:3600/20000 train_loss:2.0390 train_time:262594ms step_avg:72.94ms +step:3700/20000 train_loss:2.0750 train_time:269901ms step_avg:72.95ms +step:3800/20000 train_loss:2.1509 train_time:277097ms step_avg:72.92ms +step:3900/20000 train_loss:1.9326 train_time:284478ms step_avg:72.94ms +step:4000/20000 train_loss:2.1173 train_time:291684ms step_avg:72.92ms +step:4000/20000 val_loss:2.1091 val_bpb:1.2491 train_time:291688ms step_avg:72.92ms +step:4100/20000 train_loss:2.1289 train_time:299036ms step_avg:72.94ms +step:4200/20000 train_loss:2.1121 train_time:306374ms step_avg:72.95ms +step:4300/20000 train_loss:1.9553 train_time:313578ms step_avg:72.93ms +step:4400/20000 train_loss:2.0537 train_time:320851ms step_avg:72.92ms +step:4500/20000 train_loss:2.2022 train_time:328060ms step_avg:72.90ms +step:4500/20000 val_loss:2.1044 val_bpb:1.2464 train_time:328063ms step_avg:72.90ms +step:4600/20000 train_loss:1.9143 train_time:335384ms step_avg:72.91ms +step:4700/20000 train_loss:2.2173 train_time:342561ms step_avg:72.89ms +step:4800/20000 train_loss:2.2023 train_time:349896ms step_avg:72.90ms +step:4900/20000 train_loss:2.1133 train_time:357092ms step_avg:72.88ms +step:5000/20000 train_loss:1.9627 train_time:364416ms step_avg:72.88ms +step:5000/20000 val_loss:2.0993 val_bpb:1.2433 train_time:364423ms step_avg:72.88ms +step:5100/20000 train_loss:1.9773 train_time:371646ms step_avg:72.87ms +step:5200/20000 train_loss:2.1248 train_time:378960ms step_avg:72.88ms +step:5300/20000 train_loss:2.1537 train_time:386157ms step_avg:72.86ms +step:5400/20000 train_loss:2.1398 train_time:393471ms step_avg:72.87ms +step:5500/20000 train_loss:2.0911 train_time:400701ms step_avg:72.85ms +step:5500/20000 val_loss:2.0900 val_bpb:1.2378 train_time:400705ms step_avg:72.86ms +step:5600/20000 train_loss:2.1187 train_time:408019ms step_avg:72.86ms +step:5700/20000 train_loss:2.1094 train_time:415242ms step_avg:72.85ms +step:5800/20000 train_loss:2.0660 train_time:422549ms step_avg:72.85ms +step:5900/20000 train_loss:2.0258 train_time:429720ms step_avg:72.83ms +step:6000/20000 train_loss:2.1443 train_time:437054ms step_avg:72.84ms +step:6000/20000 val_loss:2.0676 val_bpb:1.2245 train_time:437060ms step_avg:72.84ms +step:6100/20000 train_loss:2.0433 train_time:444266ms step_avg:72.83ms +step:6200/20000 train_loss:2.0075 train_time:451623ms step_avg:72.84ms +step:6300/20000 train_loss:1.9513 train_time:458933ms step_avg:72.85ms +step:6400/20000 train_loss:2.0794 train_time:466124ms step_avg:72.83ms +step:6500/20000 train_loss:1.9901 train_time:473465ms step_avg:72.84ms +step:6500/20000 val_loss:2.0449 val_bpb:1.2111 train_time:473469ms step_avg:72.84ms +step:6600/20000 train_loss:2.0228 train_time:480655ms step_avg:72.83ms +step:6700/20000 train_loss:2.0583 train_time:487972ms step_avg:72.83ms +step:6800/20000 train_loss:2.0765 train_time:495191ms step_avg:72.82ms +step:6900/20000 train_loss:1.9937 train_time:502525ms step_avg:72.83ms +step:7000/20000 train_loss:2.1153 train_time:509716ms step_avg:72.82ms +step:7000/20000 val_loss:2.0182 val_bpb:1.1953 train_time:509720ms step_avg:72.82ms +step:7100/20000 train_loss:1.9488 train_time:517040ms step_avg:72.82ms +step:7200/20000 train_loss:2.0761 train_time:524236ms step_avg:72.81ms +step:7300/20000 train_loss:1.9711 train_time:531577ms step_avg:72.82ms +step:7400/20000 train_loss:1.9931 train_time:538777ms step_avg:72.81ms +step:7500/20000 train_loss:1.9831 train_time:546090ms step_avg:72.81ms +step:7500/20000 val_loss:1.9879 val_bpb:1.1773 train_time:546094ms step_avg:72.81ms +step:7600/20000 train_loss:1.8564 train_time:553291ms step_avg:72.80ms +step:7700/20000 train_loss:1.9353 train_time:560626ms step_avg:72.81ms +step:7800/20000 train_loss:1.9930 train_time:567823ms step_avg:72.80ms +step:7900/20000 train_loss:1.9685 train_time:575210ms step_avg:72.81ms +step:8000/20000 train_loss:1.9578 train_time:582413ms step_avg:72.80ms +step:8000/20000 val_loss:1.9552 val_bpb:1.1580 train_time:582417ms step_avg:72.80ms +step:8100/20000 train_loss:1.9856 train_time:589751ms step_avg:72.81ms +step:8200/20000 train_loss:2.0185 train_time:596905ms step_avg:72.79ms +step:8238/20000 val_loss:1.9443 val_bpb:1.1515 train_time:599779ms step_avg:72.81ms +stopping_early: wallclock_cap train_time:599779ms step:8238/20000 +peak memory: 14059 MiB +ema: applying EMA weights +Serialized model: 106313663 bytes Code: 49961 bytes +final_int8_zlib_roundtrip compressed_model_bytes:15915703 code_bytes:49961 total_artifact_bytes:15965664 +Serialized zstd-22: 15915703 bytes Total: 15965664 bytes +ttt: p1=0ep p2=25ep freeze=0 +ttt_p2: 25ep SGD lr=0.012 freeze=0 params=25974872 +ttt_p2: done in 389129ms +ttt: total 389133ms +final_eval_mode:sliding_window stride:64 batch_seqs:32 +final_roundtrip val_loss:1.8997 val_bpb:1.1251 eval_time:187291ms +final_roundtrip_exact val_loss:1.89974122 val_bpb:1.12513674 diff --git a/records/track_10min_16mb/2026-03-21_DominationV3/train_gpt.py b/records/track_10min_16mb/2026-03-21_DominationV3/train_gpt.py new file mode 100644 index 0000000000..d12b3efd12 --- /dev/null +++ b/records/track_10min_16mb/2026-03-21_DominationV3/train_gpt.py @@ -0,0 +1,889 @@ +"""Domination V3: compact no-TTT path for 10-min track.""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.5)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.01)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 25)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_p1_epochs = int(os.environ.get("TTT_P1_EPOCHS", 100)) + ttt_p1_lr = float(os.environ.get("TTT_P1_LR", 0.01)) + rope_dims = int(os.environ.get("ROPE_DIMS", 0)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "0"))) + +# ----------------------------- +# MUON OPTIMIZER WITH WEIGHT DECAY +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay)) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum = group["lr"], group["momentum"] + backend_steps, nesterov = group["backend_steps"], group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts(sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): base_bytes_np[token_id] = 1; continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): has_leading_space_np[token_id] = True; piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return (torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: raise ValueError(f"Validation split is too short for seq_len={seq_len}") + return tokens[: usable + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError(f"VAL_BATCH_SIZE too small for world={world_size} accum={grad_accum_steps} seq={args.train_seq_len}") + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + vls = torch.zeros((), device=device, dtype=torch.float64) + vtc = torch.zeros((), device=device, dtype=torch.float64) + vbc = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for bs in range(seq_start, seq_end, local_batch_seqs): + be = min(bs + local_batch_seqs, seq_end) + rs, re = bs * args.train_seq_len, be * args.train_seq_len + 1 + local = val_tokens[rs:re].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + bl = model(x, y).detach() + bc = float(y.numel()) + vls += bl.to(torch.float64) * bc; vtc += bc + pi, ti = x.reshape(-1), y.reshape(-1) + tb = base_bytes_lut[ti].to(dtype=torch.int16) + tb += (has_leading_space_lut[ti] & ~is_boundary_token_lut[pi]).to(dtype=torch.int16) + vbc += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(vls, op=dist.ReduceOp.SUM); dist.all_reduce(vtc, op=dist.ReduceOp.SUM); dist.all_reduce(vbc, op=dist.ReduceOp.SUM) + vl = vls / vtc; bpt = vl.item() / math.log(2.0); tpb = vtc.item() / vbc.item() + model.train(); return float(vl.item()), float(bpt * tpb) + +def eval_val_sliding(args, base_model, rank, world_size, device, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride, batch_seqs=32): + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) if min(ws + seq_len, total_tokens) - ws >= stride] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + ls = torch.zeros((), device=device, dtype=torch.float64) + tc = torch.zeros((), device=device, dtype=torch.float64) + bc = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + for bi in range(0, len(my_windows), batch_seqs): + bw = my_windows[bi:bi + batch_seqs] + bsz = len(bw) + xb = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + yb = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(bw): + ep = min(ws + seq_len, total_tokens) + wl = ep - ws + wlens.append(wl) + ch = val_tokens[ws:ep + 1].to(dtype=torch.int64, device=device) + xb[i, :wl] = ch[:-1] + yb[i, :wl] = ch[1:] + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(xb) + nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), yb.reshape(-1), reduction="none").reshape(bsz, seq_len) + for i, ws in enumerate(bw): + wl = wlens[i] + s = 0 if ws == 0 else wl - stride + sn = nll[i, s:wl].to(torch.float64) + ls += sn.sum() + tc += float(wl - s) + tgt, prev = yb[i, s:wl], xb[i, s:wl] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + bc += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + rb = 0.0 + if tc.item() > 0: + rl = (ls / tc).item() + rb = rl / math.log(2.0) * (tc.item() / bc.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={rb:.6f}", flush=True) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(ls, op=dist.ReduceOp.SUM); dist.all_reduce(tc, op=dist.ReduceOp.SUM); dist.all_reduce(bc, op=dist.ReduceOp.SUM) + vl = (ls / tc).item(); bpt = vl / math.log(2.0); tpb = tc.item() / bc.item() + base_model.train(); return vl, bpt * tpb + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get("CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale").split(",") if p) + +INT6_QUANT_RANGE = 31 +MIXED_QUANT_INT6_CATS = frozenset( + c.strip() for c in os.environ.get("MIXED_QUANT_INT6_CATS", "mlp,attn").split(",") if c.strip() +) +STE_QAT_ENABLED = bool(int(os.environ.get("STE_QAT_ENABLED", "0"))) +STE_QAT_RANGE = int(os.environ.get("STE_QAT_RANGE", INT6_QUANT_RANGE)) +FP16_PASSTHROUGH_PATTERNS = tuple( + p.strip() for p in os.environ.get("FP16_PASSTHROUGH_PATTERNS", "").split(",") if p.strip() +) + +def _classify_param(name): + if "tok_emb" in name or "lm_head" in name: return "embed" + if ".mlp." in name: return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): return "attn" + return "other" + +def _get_ste_range_for_param(name): + return STE_QAT_RANGE + +CLIP_PCTS = [0.999, 0.9995, 0.9999, 0.99999, 1.0] + +def quantize_int6_per_row(t): + t32 = t.float() + if t32.ndim == 2: + best_q, best_sc, best_mse = None, None, float("inf") + for pct in CLIP_PCTS: + rm = torch.quantile(t32.abs(), pct, dim=1) + sc = (rm / 31.0).clamp_min(1.0 / 31.0).to(torch.float16) + tc = torch.clamp(t32, -rm[:, None], rm[:, None]) + q = torch.clamp(torch.round(tc / sc.float()[:, None]), -32, 31).to(torch.int8) + recon = q.float() * sc.float()[:, None] + mse = (t32 - recon).square().mean().item() + if mse < best_mse: best_mse = mse; best_q = q; best_sc = sc + return best_q, best_sc + am = t32.abs().max().item() + sc = torch.tensor(am / 31.0 if am > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / sc.float()), -32, 31).to(torch.int8) + return q, sc + +def quantize_int8_per_row(t): + t32 = t.float() + if t32.ndim == 2: + rm = t32.abs().amax(dim=1) + sc = (rm / 127.0).clamp_min(1e-8).to(torch.float16) + q = torch.clamp(torch.round(t32 / sc.float()[:, None]), -127, 127).to(torch.int8) + return q, sc + am = t32.abs().max().item() + sc = torch.tensor(am / 127.0 if am > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / sc.float()), -127, 127).to(torch.int8) + return q, sc + +def mixed_quantize(state_dict, int6_cats): + result, meta = {}, {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if t.is_floating_point() and FP16_PASSTHROUGH_PATTERNS and any(p in name for p in FP16_PASSTHROUGH_PATTERNS): + result[name] = t.to(torch.float16) + meta[name] = "passthrough_fp16" + continue + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q; result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_int8_per_row(t) + result[name + ".q"] = q; result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed(result, meta, template_sd): + out = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: continue + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig.dtype in (torch.float32, torch.bfloat16): + t = t.to(orig.dtype) + out[name] = t; continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig.dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig.dtype) + return out + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file): + hb = 256 * np.dtype(" 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._advance_file(); continue + k = min(remaining, avail); chunks.append(self.tokens[self.pos:self.pos + k]); self.pos += k; remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern, rank, world_size, device): + self.rank, self.world_size, self.device, self.stream = rank, world_size, device, TokenStream(pattern) + def next_batch(self, global_tokens, seq_len, grad_accum_steps): + lt = global_tokens // (self.world_size * grad_accum_steps); prs = lt + 1 + chunk = self.stream.take(prs * self.world_size); s = self.rank * prs + local = chunk[s:s + prs].to(dtype=torch.int64) + x, y = local[:-1].reshape(-1, seq_len), local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps=None): super().__init__(); self.eps = eps + def forward(self, x): return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +class CastedLinear(nn.Linear): + def forward(self, x): + w = self.weight.to(x.dtype) + if self.training and STE_QAT_ENABLED and w.ndim == 2: + with torch.no_grad(): + w32 = w.float(); rm = w32.abs().amax(dim=1).clamp_min(1e-8) + sc = rm / 31.0; wc = torch.clamp(w32, -rm[:, None], rm[:, None]) + wq = (torch.round(wc / sc[:, None]) * sc[:, None]).to(x.dtype) + w = w + (wq - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + +def restore_low_dim_params_to_fp32(module): + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + +class Rotary(nn.Module): + def __init__(self, dim, base=10000.0, train_seq_len=1024, rope_dims=0): + super().__init__() + self._base = base; self._train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + rd = self.rope_dims + self.register_buffer("inv_freq", 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)), persistent=False) + self._seq_len_cached = 0; self._cos_cached = None; self._sin_cached = None + def forward(self, seq_len, device, dtype): + if self._cos_cached is None or self._seq_len_cached != seq_len or self._cos_cached.device != device: + rd = self.rope_dims + if seq_len > self._train_seq_len: + scale = seq_len / self._train_seq_len + new_base = self._base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, None, :, :]; self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +def apply_rotary_emb(x, cos, sin, rope_dims=0): + if rope_dims > 0 and rope_dims < x.size(-1): + xr, xp = x[..., :rope_dims], x[..., rope_dims:] + h = rope_dims // 2; x1, x2 = xr[..., :h], xr[..., h:] + xr = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((xr, xp), dim=-1) + h = x.size(-1) // 2; x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=0): + super().__init__() + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads; self.head_dim = dim // num_heads + kv_dim = num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False); self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False); self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = rope_dims + self.rotary = Rotary(self.head_dim, base=rope_base, rope_dims=rope_dims) + self.use_xsa = False + + def _xsa_efficient(self, y, v): + B, T, H, D = y.shape + Hkv = v.size(2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x): + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True, enable_gqa=(self.num_kv_heads != self.num_heads)) + y = y.transpose(1, 2).contiguous() + if self.use_xsa: + v_bthd = v.transpose(1, 2).contiguous() + y = self._xsa_efficient(y, v_bthd) + return self.proj(y.reshape(bsz, seqlen, dim)) + +class MLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__(); hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False); self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x): return self.proj(torch.relu(self.fc(x)).square()) + +class SmearGate(nn.Module): + """Per-dimension SmearGate (from PR #194): each dim has its own blend ratio.""" + def __init__(self, dim): + super().__init__(); self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x): + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size, bigram_dim, model_dim): + super().__init__(); self.bvs = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim); nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def forward(self, token_ids): + t = token_ids.to(torch.int32); mod = self.bvs - 1 + out = torch.empty_like(t); out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + h = self.embed(out.long()) + if self.proj is not None: h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class Block(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, rope_dims=0, ln_sf=1.0): + super().__init__() + self.attn_norm, self.mlp_norm = RMSNorm(), RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_sf = ln_sf + def forward(self, x, x0): + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_sf + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x) * s) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) + return x + +class GPT(nn.Module): + def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, mlp_mult, + tie_embeddings, tied_embed_init_std, logit_softcap, rope_base, qk_gain_init, + bigram_vocab_size=0, bigram_dim=128, xsa_last_n=0, rope_dims=0, ln_scale=False): + super().__init__() + self.tie_embeddings, self.tied_embed_init_std, self.logit_softcap = tie_embeddings, tied_embed_init_std, logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, rope_dims=rope_dims, ln_sf=1.0 / math.sqrt(i + 1) if ln_scale else 1.0) for i in range(num_layers)]) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: self.lm_head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self): + if self.tie_embeddings: nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + nl = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): module.weight.mul_(1.0 / math.sqrt(2 * nl)) + + def _run_blocks(self, x, x0): + skips = [] + for i in range(self.num_encoder_layers): x = self.blocks[i](x, x0); skips.append(x) + for i in range(self.num_decoder_layers): + if skips: x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + return x + + def _embed(self, input_ids): + x = self.tok_emb(input_ids) + if self.bigram is not None: x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + return self.smear(x) + + def _logits(self, x): + if self.tie_embeddings: lp = F.linear(x, self.tok_emb.weight) + else: lp = self.lm_head(x) + return self.logit_softcap * torch.tanh(lp / self.logit_softcap) + + def forward(self, input_ids, target_ids): + x0 = self._embed(input_ids) + x = self.final_norm(self._run_blocks(x0, x0)) + x_flat = x.reshape(-1, x.size(-1)) + return F.cross_entropy(self._logits(x_flat).float(), target_ids.reshape(-1), reduction="mean") + + def forward_logits(self, input_ids): + x0 = self._embed(input_ids) + return self._logits(self.final_norm(self._run_blocks(x0, x0))) + +# ----------------------------- +# TRAINING +# ----------------------------- + +def _ttt_run(mdl, opt, epochs, rank, world_size, device, val_tokens, sl, batch_seqs): + nt = val_tokens.numel() - 1; ts = nt // sl + ms, me = (ts * rank) // world_size, (ts * (rank + 1)) // world_size + mdl.train() + for _ in range(epochs): + for bs in range(ms, me, batch_seqs): + be = min(bs + batch_seqs, me) + rs, re = bs * sl, be * sl + 1 + loc = val_tokens[rs:re].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = loc[:-1].reshape(-1, sl), loc[1:].reshape(-1, sl) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): loss = mdl(x, y) + opt.zero_grad(set_to_none=True); loss.backward(); opt.step() + if dist.is_available() and dist.is_initialized(): dist.barrier() + +def ttt_adapt(args, mdl, rank, world_size, device, val_tokens, log0): + sl = args.train_seq_len + for p in mdl.parameters(): p.requires_grad_(False) + if args.ttt_p1_epochs > 0: + norm_params = [] + for n, p in mdl.named_parameters(): + if "norm" in n or "attn_scale" in n or "mlp_scale" in n or "resid_mix" in n or "skip_weight" in n: + p.requires_grad_(True); norm_params.append(p) + log0(f"ttt_p1: {args.ttt_p1_epochs}ep Adam lr={args.ttt_p1_lr} params={sum(p.numel() for p in norm_params)}") + opt1 = torch.optim.Adam(norm_params, lr=args.ttt_p1_lr) + torch.cuda.synchronize(); t0 = time.perf_counter() + _ttt_run(mdl, opt1, args.ttt_p1_epochs, rank, world_size, device, val_tokens, sl, args.ttt_batch_seqs) + torch.cuda.synchronize(); log0(f"ttt_p1: done in {1000.0 * (time.perf_counter() - t0):.0f}ms") + for p in norm_params: p.requires_grad_(False) + del opt1 + if args.ttt_epochs > 0: + nl = len(mdl.blocks) + for i, b in enumerate(mdl.blocks): + req = i >= args.ttt_freeze_blocks + for p in b.parameters(): p.requires_grad_(req) + trainable = [p for p in mdl.parameters() if p.requires_grad] + log0(f"ttt_p2: {args.ttt_epochs}ep SGD lr={args.ttt_lr} freeze={args.ttt_freeze_blocks} params={sum(p.numel() for p in trainable)}") + opt2 = torch.optim.SGD(trainable, lr=args.ttt_lr, momentum=args.ttt_momentum) + torch.cuda.synchronize(); t0 = time.perf_counter() + _ttt_run(mdl, opt2, args.ttt_epochs, rank, world_size, device, val_tokens, sl, args.ttt_batch_seqs) + torch.cuda.synchronize(); log0(f"ttt_p2: done in {1000.0 * (time.perf_counter() - t0):.0f}ms") + del opt2 + for p in mdl.parameters(): p.requires_grad_(False) + mdl.eval() + +def main(): + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8"); args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank, world_size = int(os.environ.get("RANK", "0")), int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + grad_accum_steps = 8 // world_size; grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank); torch.cuda.set_device(device) + if distributed: dist.init_process_group(backend="nccl", device_id=device); dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True; torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False); enable_flash_sdp(True); enable_mem_efficient_sdp(False); enable_math_sdp(False) + logfile = None + if master_process: os.makedirs("logs", exist_ok=True); logfile = f"logs/{args.run_id}.txt"; print(logfile) + def log0(msg, console=True): + if not master_process: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + log0(code, console=False); log0("=" * 100, console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + log0("=" * 100, console=False) + random.seed(args.seed); np.random.seed(args.seed); torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:tokens:{val_tokens.numel() - 1}") + + base_model = GPT(vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [p for n, p in block_named_params if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = [p for n, p in block_named_params if p.ndim < 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + if base_model.skip_weights.numel() > 0: scalar_params.append(base_model.skip_weights) + for p in base_model.smear.parameters(): scalar_params.append(p) + if base_model.bigram is not None: + for n, p in base_model.bigram.named_parameters(): + if p.ndim == 2 and p.shape[0] >= 64 and p.shape[1] >= 64: matrix_params.append(p) + else: scalar_params.append(p) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.AdamW([{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.weight_decay, fused=True) + optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, weight_decay=args.muon_wd) + for group in optimizer_muon.param_groups: group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW([{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.weight_decay, fused=True) + optimizers = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizers.insert(1, torch.optim.AdamW([{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.weight_decay, fused=True)) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} swa:{args.swa_enabled} compression:{'zstd-22' if HAS_ZSTD else 'zlib-9'}") + log0(f"bigram_vocab:{args.bigram_vocab_size} bigram_dim:{args.bigram_dim} grad_clip:{args.grad_clip_norm} muon_wd:{args.muon_wd}") + log0(f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} warmdown:{args.warmdown_iters} seed:{args.seed}") + log0(f"ste_qat:{STE_QAT_ENABLED} ste_range:{STE_QAT_RANGE} int6_cats:{MIXED_QUANT_INT6_CATS}") + log0(f"fp16_passthrough:{FP16_PASSTHROUGH_PATTERNS}") + log0(f"xsa_last_n:{args.xsa_last_n} ema:{args.ema_enabled} ema_decay:{args.ema_decay} rope_dims:{args.rope_dims} ln_scale:{args.ln_scale}") + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all(): + for opt in optimizers: opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step, elapsed_ms): + if args.warmdown_iters <= 0: return 1.0 + if max_wallclock_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1); wms = args.warmdown_iters * step_ms + rms = max(max_wallclock_ms - elapsed_ms, 0.0) + return rms / max(wms, 1e-9) if rms <= wms else 1.0 + + if args.warmup_steps > 0: + init_sd = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers]; model.train() + for ws in range(args.warmup_steps): + zero_grad_all() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = ms == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): wl = model(x, y) + (wl * grad_scale).backward() + for o in optimizers: o.step() + zero_grad_all() + if args.warmup_steps <= 20 or (ws + 1) % 10 == 0: log0(f"warmup_step:{ws + 1}/{args.warmup_steps}") + base_model.load_state_dict(init_sd, strict=True) + for o, s in zip(optimizers, init_opts, strict=True): o.load_state_dict(s) + zero_grad_all() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + training_time_ms = 0.0; stop_after_step = None; swa_state = None; swa_count = 0 + ema_state = None + if args.ema_enabled: + ema_state = {n: t.detach().float().clone() for n, t in base_model.state_dict().items()} + torch.cuda.synchronize(); t0 = time.perf_counter(); step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + torch.cuda.synchronize(); training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vb = eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vb:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0); scale = lr_mul(step, elapsed_ms) + if args.swa_enabled and not args.ema_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + current = {k: v.detach().cpu().clone() for k, v in base_model.state_dict().items()} + if swa_state is None: + swa_state = current; swa_count = 1 + else: + inv = 1.0 / (swa_count + 1); keep = 1.0 - inv + for k, t in current.items(): + if torch.is_floating_point(swa_state[k]): swa_state[k].mul_(keep).add_(t, alpha=inv) + else: swa_state[k] = t + swa_count += 1 + zero_grad_all(); train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = ms == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): loss = model(x, y) + train_loss += loss.detach(); (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + for group in optimizer_muon.param_groups: group["momentum"] = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for opt in optimizers: + for group in opt.param_groups: group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: opt.step() + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for n, t in base_model.state_dict().items(): + ema_state[n].mul_(d).add_(t.detach().float(), alpha=1.0 - d) + zero_grad_all(); step += 1 + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None): + log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} train_time:{approx_ms:.0f}ms step_avg:{approx_ms / step:.2f}ms") + reached_cap = max_wallclock_ms is not None and approx_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + rct = torch.tensor(int(reached_cap), device=device); dist.all_reduce(rct, op=dist.ReduceOp.MAX); reached_cap = bool(rct.item()) + if stop_after_step is None and reached_cap: stop_after_step = step + + log0(f"peak memory: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") + if ema_state is not None: + log0("ema: applying EMA weights") + avg_sd = {n: t.to(dtype=base_model.state_dict()[n].dtype) for n, t in ema_state.items()} + base_model.load_state_dict(avg_sd, strict=True); del ema_state, avg_sd + elif swa_state is not None: + log0(f"swa: averaging {swa_count} checkpoints") + base_model.load_state_dict(swa_state, strict=True); del swa_state + + export_sd = base_model.state_dict() + if master_process: + torch.save(export_sd, "final_model.pt") + log0(f"Serialized model: {os.path.getsize('final_model.pt')} bytes Code: {len(code.encode('utf-8'))} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + quant_result, quant_meta = mixed_quantize(sd_cpu, MIXED_QUANT_INT6_CATS) + qbuf = io.BytesIO(); torch.save({"w": quant_result, "m": quant_meta}, qbuf); qraw = qbuf.getvalue() + if HAS_ZSTD: qblob = zstd.ZstdCompressor(level=22).compress(qraw); cl = "zstd-22" + else: qblob = zlib.compress(qraw, level=9); cl = "zlib-9" + if master_process: + with open("final_model.int8.ptz", "wb") as f: f.write(qblob) + qfb = len(qblob); cb = len(code.encode("utf-8")) + log0(f"final_int8_zlib_roundtrip compressed_model_bytes:{qfb} code_bytes:{cb} total_artifact_bytes:{qfb + cb}") + log0(f"Serialized {cl}: {qfb} bytes Total: {qfb + cb} bytes") + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: qbd = f.read() + rd = zstd.ZstdDecompressor().decompress(qbd) if HAS_ZSTD else zlib.decompress(qbd) + qs = torch.load(io.BytesIO(rd), map_location="cpu") + deq_sd = dequantize_mixed(qs["w"], qs["m"], sd_cpu) + eval_model = GPT(vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale).to(device).bfloat16() + for module in eval_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_sd, strict=True) + + if args.ttt_enabled: + log0(f"ttt: p1={args.ttt_p1_epochs}ep p2={args.ttt_epochs}ep freeze={args.ttt_freeze_blocks}") + torch.cuda.synchronize(); t_ttt = time.perf_counter() + ttt_adapt(args, eval_model, rank, world_size, device, val_tokens, log0) + torch.cuda.synchronize(); log0(f"ttt: total {1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + + torch.cuda.synchronize(); tqe = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + qvl, qvb = eval_val_sliding(args, eval_model, rank, world_size, device, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, stride=args.eval_stride, batch_seqs=args.eval_batch_seqs) + else: + qvl, qvb = eval_val(args, eval_model, rank, world_size, device, grad_accum_steps, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"final_roundtrip val_loss:{qvl:.4f} val_bpb:{qvb:.4f} eval_time:{1000.0 * (time.perf_counter() - tqe):.0f}ms") + log0(f"final_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvb:.8f}") + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-21_DominationV3/train_seed42.log b/records/track_10min_16mb/2026-03-21_DominationV3/train_seed42.log new file mode 100644 index 0000000000..120f0f9abc --- /dev/null +++ b/records/track_10min_16mb/2026-03-21_DominationV3/train_seed42.log @@ -0,0 +1,1101 @@ +"""Domination V3: compact no-TTT path for 10-min track.""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.5)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.01)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 25)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_p1_epochs = int(os.environ.get("TTT_P1_EPOCHS", 100)) + ttt_p1_lr = float(os.environ.get("TTT_P1_LR", 0.01)) + rope_dims = int(os.environ.get("ROPE_DIMS", 0)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "0"))) + +# ----------------------------- +# MUON OPTIMIZER WITH WEIGHT DECAY +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay)) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum = group["lr"], group["momentum"] + backend_steps, nesterov = group["backend_steps"], group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts(sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): base_bytes_np[token_id] = 1; continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): has_leading_space_np[token_id] = True; piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return (torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: raise ValueError(f"Validation split is too short for seq_len={seq_len}") + return tokens[: usable + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError(f"VAL_BATCH_SIZE too small for world={world_size} accum={grad_accum_steps} seq={args.train_seq_len}") + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + vls = torch.zeros((), device=device, dtype=torch.float64) + vtc = torch.zeros((), device=device, dtype=torch.float64) + vbc = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for bs in range(seq_start, seq_end, local_batch_seqs): + be = min(bs + local_batch_seqs, seq_end) + rs, re = bs * args.train_seq_len, be * args.train_seq_len + 1 + local = val_tokens[rs:re].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + bl = model(x, y).detach() + bc = float(y.numel()) + vls += bl.to(torch.float64) * bc; vtc += bc + pi, ti = x.reshape(-1), y.reshape(-1) + tb = base_bytes_lut[ti].to(dtype=torch.int16) + tb += (has_leading_space_lut[ti] & ~is_boundary_token_lut[pi]).to(dtype=torch.int16) + vbc += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(vls, op=dist.ReduceOp.SUM); dist.all_reduce(vtc, op=dist.ReduceOp.SUM); dist.all_reduce(vbc, op=dist.ReduceOp.SUM) + vl = vls / vtc; bpt = vl.item() / math.log(2.0); tpb = vtc.item() / vbc.item() + model.train(); return float(vl.item()), float(bpt * tpb) + +def eval_val_sliding(args, base_model, rank, world_size, device, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride, batch_seqs=32): + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) if min(ws + seq_len, total_tokens) - ws >= stride] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + ls = torch.zeros((), device=device, dtype=torch.float64) + tc = torch.zeros((), device=device, dtype=torch.float64) + bc = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + for bi in range(0, len(my_windows), batch_seqs): + bw = my_windows[bi:bi + batch_seqs] + bsz = len(bw) + xb = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + yb = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(bw): + ep = min(ws + seq_len, total_tokens) + wl = ep - ws + wlens.append(wl) + ch = val_tokens[ws:ep + 1].to(dtype=torch.int64, device=device) + xb[i, :wl] = ch[:-1] + yb[i, :wl] = ch[1:] + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(xb) + nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), yb.reshape(-1), reduction="none").reshape(bsz, seq_len) + for i, ws in enumerate(bw): + wl = wlens[i] + s = 0 if ws == 0 else wl - stride + sn = nll[i, s:wl].to(torch.float64) + ls += sn.sum() + tc += float(wl - s) + tgt, prev = yb[i, s:wl], xb[i, s:wl] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + bc += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + rb = 0.0 + if tc.item() > 0: + rl = (ls / tc).item() + rb = rl / math.log(2.0) * (tc.item() / bc.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={rb:.6f}", flush=True) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(ls, op=dist.ReduceOp.SUM); dist.all_reduce(tc, op=dist.ReduceOp.SUM); dist.all_reduce(bc, op=dist.ReduceOp.SUM) + vl = (ls / tc).item(); bpt = vl / math.log(2.0); tpb = tc.item() / bc.item() + base_model.train(); return vl, bpt * tpb + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get("CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale").split(",") if p) + +INT6_QUANT_RANGE = 31 +MIXED_QUANT_INT6_CATS = frozenset( + c.strip() for c in os.environ.get("MIXED_QUANT_INT6_CATS", "mlp,attn").split(",") if c.strip() +) +STE_QAT_ENABLED = bool(int(os.environ.get("STE_QAT_ENABLED", "0"))) +STE_QAT_RANGE = int(os.environ.get("STE_QAT_RANGE", INT6_QUANT_RANGE)) +FP16_PASSTHROUGH_PATTERNS = tuple( + p.strip() for p in os.environ.get("FP16_PASSTHROUGH_PATTERNS", "").split(",") if p.strip() +) + +def _classify_param(name): + if "tok_emb" in name or "lm_head" in name: return "embed" + if ".mlp." in name: return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): return "attn" + return "other" + +def _get_ste_range_for_param(name): + return STE_QAT_RANGE + +CLIP_PCTS = [0.999, 0.9995, 0.9999, 0.99999, 1.0] + +def quantize_int6_per_row(t): + t32 = t.float() + if t32.ndim == 2: + best_q, best_sc, best_mse = None, None, float("inf") + for pct in CLIP_PCTS: + rm = torch.quantile(t32.abs(), pct, dim=1) + sc = (rm / 31.0).clamp_min(1.0 / 31.0).to(torch.float16) + tc = torch.clamp(t32, -rm[:, None], rm[:, None]) + q = torch.clamp(torch.round(tc / sc.float()[:, None]), -32, 31).to(torch.int8) + recon = q.float() * sc.float()[:, None] + mse = (t32 - recon).square().mean().item() + if mse < best_mse: best_mse = mse; best_q = q; best_sc = sc + return best_q, best_sc + am = t32.abs().max().item() + sc = torch.tensor(am / 31.0 if am > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / sc.float()), -32, 31).to(torch.int8) + return q, sc + +def quantize_int8_per_row(t): + t32 = t.float() + if t32.ndim == 2: + rm = t32.abs().amax(dim=1) + sc = (rm / 127.0).clamp_min(1e-8).to(torch.float16) + q = torch.clamp(torch.round(t32 / sc.float()[:, None]), -127, 127).to(torch.int8) + return q, sc + am = t32.abs().max().item() + sc = torch.tensor(am / 127.0 if am > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / sc.float()), -127, 127).to(torch.int8) + return q, sc + +def mixed_quantize(state_dict, int6_cats): + result, meta = {}, {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if t.is_floating_point() and FP16_PASSTHROUGH_PATTERNS and any(p in name for p in FP16_PASSTHROUGH_PATTERNS): + result[name] = t.to(torch.float16) + meta[name] = "passthrough_fp16" + continue + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q; result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_int8_per_row(t) + result[name + ".q"] = q; result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed(result, meta, template_sd): + out = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: continue + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig.dtype in (torch.float32, torch.bfloat16): + t = t.to(orig.dtype) + out[name] = t; continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig.dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig.dtype) + return out + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file): + hb = 256 * np.dtype(" 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._advance_file(); continue + k = min(remaining, avail); chunks.append(self.tokens[self.pos:self.pos + k]); self.pos += k; remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern, rank, world_size, device): + self.rank, self.world_size, self.device, self.stream = rank, world_size, device, TokenStream(pattern) + def next_batch(self, global_tokens, seq_len, grad_accum_steps): + lt = global_tokens // (self.world_size * grad_accum_steps); prs = lt + 1 + chunk = self.stream.take(prs * self.world_size); s = self.rank * prs + local = chunk[s:s + prs].to(dtype=torch.int64) + x, y = local[:-1].reshape(-1, seq_len), local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps=None): super().__init__(); self.eps = eps + def forward(self, x): return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +class CastedLinear(nn.Linear): + def forward(self, x): + w = self.weight.to(x.dtype) + if self.training and STE_QAT_ENABLED and w.ndim == 2: + with torch.no_grad(): + w32 = w.float(); rm = w32.abs().amax(dim=1).clamp_min(1e-8) + sc = rm / 31.0; wc = torch.clamp(w32, -rm[:, None], rm[:, None]) + wq = (torch.round(wc / sc[:, None]) * sc[:, None]).to(x.dtype) + w = w + (wq - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + +def restore_low_dim_params_to_fp32(module): + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + +class Rotary(nn.Module): + def __init__(self, dim, base=10000.0, train_seq_len=1024, rope_dims=0): + super().__init__() + self._base = base; self._train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + rd = self.rope_dims + self.register_buffer("inv_freq", 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)), persistent=False) + self._seq_len_cached = 0; self._cos_cached = None; self._sin_cached = None + def forward(self, seq_len, device, dtype): + if self._cos_cached is None or self._seq_len_cached != seq_len or self._cos_cached.device != device: + rd = self.rope_dims + if seq_len > self._train_seq_len: + scale = seq_len / self._train_seq_len + new_base = self._base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, None, :, :]; self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +def apply_rotary_emb(x, cos, sin, rope_dims=0): + if rope_dims > 0 and rope_dims < x.size(-1): + xr, xp = x[..., :rope_dims], x[..., rope_dims:] + h = rope_dims // 2; x1, x2 = xr[..., :h], xr[..., h:] + xr = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((xr, xp), dim=-1) + h = x.size(-1) // 2; x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=0): + super().__init__() + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads; self.head_dim = dim // num_heads + kv_dim = num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False); self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False); self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = rope_dims + self.rotary = Rotary(self.head_dim, base=rope_base, rope_dims=rope_dims) + self.use_xsa = False + + def _xsa_efficient(self, y, v): + B, T, H, D = y.shape + Hkv = v.size(2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x): + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True, enable_gqa=(self.num_kv_heads != self.num_heads)) + y = y.transpose(1, 2).contiguous() + if self.use_xsa: + v_bthd = v.transpose(1, 2).contiguous() + y = self._xsa_efficient(y, v_bthd) + return self.proj(y.reshape(bsz, seqlen, dim)) + +class MLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__(); hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False); self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x): return self.proj(torch.relu(self.fc(x)).square()) + +class SmearGate(nn.Module): + """Per-dimension SmearGate (from PR #194): each dim has its own blend ratio.""" + def __init__(self, dim): + super().__init__(); self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x): + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size, bigram_dim, model_dim): + super().__init__(); self.bvs = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim); nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def forward(self, token_ids): + t = token_ids.to(torch.int32); mod = self.bvs - 1 + out = torch.empty_like(t); out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + h = self.embed(out.long()) + if self.proj is not None: h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class Block(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, rope_dims=0, ln_sf=1.0): + super().__init__() + self.attn_norm, self.mlp_norm = RMSNorm(), RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_sf = ln_sf + def forward(self, x, x0): + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_sf + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x) * s) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) + return x + +class GPT(nn.Module): + def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, mlp_mult, + tie_embeddings, tied_embed_init_std, logit_softcap, rope_base, qk_gain_init, + bigram_vocab_size=0, bigram_dim=128, xsa_last_n=0, rope_dims=0, ln_scale=False): + super().__init__() + self.tie_embeddings, self.tied_embed_init_std, self.logit_softcap = tie_embeddings, tied_embed_init_std, logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, rope_dims=rope_dims, ln_sf=1.0 / math.sqrt(i + 1) if ln_scale else 1.0) for i in range(num_layers)]) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: self.lm_head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self): + if self.tie_embeddings: nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + nl = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): module.weight.mul_(1.0 / math.sqrt(2 * nl)) + + def _run_blocks(self, x, x0): + skips = [] + for i in range(self.num_encoder_layers): x = self.blocks[i](x, x0); skips.append(x) + for i in range(self.num_decoder_layers): + if skips: x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + return x + + def _embed(self, input_ids): + x = self.tok_emb(input_ids) + if self.bigram is not None: x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + return self.smear(x) + + def _logits(self, x): + if self.tie_embeddings: lp = F.linear(x, self.tok_emb.weight) + else: lp = self.lm_head(x) + return self.logit_softcap * torch.tanh(lp / self.logit_softcap) + + def forward(self, input_ids, target_ids): + x0 = self._embed(input_ids) + x = self.final_norm(self._run_blocks(x0, x0)) + x_flat = x.reshape(-1, x.size(-1)) + return F.cross_entropy(self._logits(x_flat).float(), target_ids.reshape(-1), reduction="mean") + + def forward_logits(self, input_ids): + x0 = self._embed(input_ids) + return self._logits(self.final_norm(self._run_blocks(x0, x0))) + +# ----------------------------- +# TRAINING +# ----------------------------- + +def _ttt_run(mdl, opt, epochs, rank, world_size, device, val_tokens, sl, batch_seqs): + nt = val_tokens.numel() - 1; ts = nt // sl + ms, me = (ts * rank) // world_size, (ts * (rank + 1)) // world_size + mdl.train() + for _ in range(epochs): + for bs in range(ms, me, batch_seqs): + be = min(bs + batch_seqs, me) + rs, re = bs * sl, be * sl + 1 + loc = val_tokens[rs:re].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = loc[:-1].reshape(-1, sl), loc[1:].reshape(-1, sl) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): loss = mdl(x, y) + opt.zero_grad(set_to_none=True); loss.backward(); opt.step() + if dist.is_available() and dist.is_initialized(): dist.barrier() + +def ttt_adapt(args, mdl, rank, world_size, device, val_tokens, log0): + sl = args.train_seq_len + for p in mdl.parameters(): p.requires_grad_(False) + if args.ttt_p1_epochs > 0: + norm_params = [] + for n, p in mdl.named_parameters(): + if "norm" in n or "attn_scale" in n or "mlp_scale" in n or "resid_mix" in n or "skip_weight" in n: + p.requires_grad_(True); norm_params.append(p) + log0(f"ttt_p1: {args.ttt_p1_epochs}ep Adam lr={args.ttt_p1_lr} params={sum(p.numel() for p in norm_params)}") + opt1 = torch.optim.Adam(norm_params, lr=args.ttt_p1_lr) + torch.cuda.synchronize(); t0 = time.perf_counter() + _ttt_run(mdl, opt1, args.ttt_p1_epochs, rank, world_size, device, val_tokens, sl, args.ttt_batch_seqs) + torch.cuda.synchronize(); log0(f"ttt_p1: done in {1000.0 * (time.perf_counter() - t0):.0f}ms") + for p in norm_params: p.requires_grad_(False) + del opt1 + if args.ttt_epochs > 0: + nl = len(mdl.blocks) + for i, b in enumerate(mdl.blocks): + req = i >= args.ttt_freeze_blocks + for p in b.parameters(): p.requires_grad_(req) + trainable = [p for p in mdl.parameters() if p.requires_grad] + log0(f"ttt_p2: {args.ttt_epochs}ep SGD lr={args.ttt_lr} freeze={args.ttt_freeze_blocks} params={sum(p.numel() for p in trainable)}") + opt2 = torch.optim.SGD(trainable, lr=args.ttt_lr, momentum=args.ttt_momentum) + torch.cuda.synchronize(); t0 = time.perf_counter() + _ttt_run(mdl, opt2, args.ttt_epochs, rank, world_size, device, val_tokens, sl, args.ttt_batch_seqs) + torch.cuda.synchronize(); log0(f"ttt_p2: done in {1000.0 * (time.perf_counter() - t0):.0f}ms") + del opt2 + for p in mdl.parameters(): p.requires_grad_(False) + mdl.eval() + +def main(): + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8"); args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank, world_size = int(os.environ.get("RANK", "0")), int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + grad_accum_steps = 8 // world_size; grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank); torch.cuda.set_device(device) + if distributed: dist.init_process_group(backend="nccl", device_id=device); dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True; torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False); enable_flash_sdp(True); enable_mem_efficient_sdp(False); enable_math_sdp(False) + logfile = None + if master_process: os.makedirs("logs", exist_ok=True); logfile = f"logs/{args.run_id}.txt"; print(logfile) + def log0(msg, console=True): + if not master_process: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + log0(code, console=False); log0("=" * 100, console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + log0("=" * 100, console=False) + random.seed(args.seed); np.random.seed(args.seed); torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:tokens:{val_tokens.numel() - 1}") + + base_model = GPT(vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [p for n, p in block_named_params if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = [p for n, p in block_named_params if p.ndim < 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + if base_model.skip_weights.numel() > 0: scalar_params.append(base_model.skip_weights) + for p in base_model.smear.parameters(): scalar_params.append(p) + if base_model.bigram is not None: + for n, p in base_model.bigram.named_parameters(): + if p.ndim == 2 and p.shape[0] >= 64 and p.shape[1] >= 64: matrix_params.append(p) + else: scalar_params.append(p) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.AdamW([{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.weight_decay, fused=True) + optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, weight_decay=args.muon_wd) + for group in optimizer_muon.param_groups: group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW([{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.weight_decay, fused=True) + optimizers = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizers.insert(1, torch.optim.AdamW([{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.weight_decay, fused=True)) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} swa:{args.swa_enabled} compression:{'zstd-22' if HAS_ZSTD else 'zlib-9'}") + log0(f"bigram_vocab:{args.bigram_vocab_size} bigram_dim:{args.bigram_dim} grad_clip:{args.grad_clip_norm} muon_wd:{args.muon_wd}") + log0(f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} warmdown:{args.warmdown_iters} seed:{args.seed}") + log0(f"ste_qat:{STE_QAT_ENABLED} ste_range:{STE_QAT_RANGE} int6_cats:{MIXED_QUANT_INT6_CATS}") + log0(f"fp16_passthrough:{FP16_PASSTHROUGH_PATTERNS}") + log0(f"xsa_last_n:{args.xsa_last_n} ema:{args.ema_enabled} ema_decay:{args.ema_decay} rope_dims:{args.rope_dims} ln_scale:{args.ln_scale}") + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all(): + for opt in optimizers: opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step, elapsed_ms): + if args.warmdown_iters <= 0: return 1.0 + if max_wallclock_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1); wms = args.warmdown_iters * step_ms + rms = max(max_wallclock_ms - elapsed_ms, 0.0) + return rms / max(wms, 1e-9) if rms <= wms else 1.0 + + if args.warmup_steps > 0: + init_sd = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers]; model.train() + for ws in range(args.warmup_steps): + zero_grad_all() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = ms == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): wl = model(x, y) + (wl * grad_scale).backward() + for o in optimizers: o.step() + zero_grad_all() + if args.warmup_steps <= 20 or (ws + 1) % 10 == 0: log0(f"warmup_step:{ws + 1}/{args.warmup_steps}") + base_model.load_state_dict(init_sd, strict=True) + for o, s in zip(optimizers, init_opts, strict=True): o.load_state_dict(s) + zero_grad_all() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + training_time_ms = 0.0; stop_after_step = None; swa_state = None; swa_count = 0 + ema_state = None + if args.ema_enabled: + ema_state = {n: t.detach().float().clone() for n, t in base_model.state_dict().items()} + torch.cuda.synchronize(); t0 = time.perf_counter(); step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + torch.cuda.synchronize(); training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vb = eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vb:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0); scale = lr_mul(step, elapsed_ms) + if args.swa_enabled and not args.ema_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + current = {k: v.detach().cpu().clone() for k, v in base_model.state_dict().items()} + if swa_state is None: + swa_state = current; swa_count = 1 + else: + inv = 1.0 / (swa_count + 1); keep = 1.0 - inv + for k, t in current.items(): + if torch.is_floating_point(swa_state[k]): swa_state[k].mul_(keep).add_(t, alpha=inv) + else: swa_state[k] = t + swa_count += 1 + zero_grad_all(); train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = ms == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): loss = model(x, y) + train_loss += loss.detach(); (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + for group in optimizer_muon.param_groups: group["momentum"] = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for opt in optimizers: + for group in opt.param_groups: group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: opt.step() + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for n, t in base_model.state_dict().items(): + ema_state[n].mul_(d).add_(t.detach().float(), alpha=1.0 - d) + zero_grad_all(); step += 1 + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None): + log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} train_time:{approx_ms:.0f}ms step_avg:{approx_ms / step:.2f}ms") + reached_cap = max_wallclock_ms is not None and approx_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + rct = torch.tensor(int(reached_cap), device=device); dist.all_reduce(rct, op=dist.ReduceOp.MAX); reached_cap = bool(rct.item()) + if stop_after_step is None and reached_cap: stop_after_step = step + + log0(f"peak memory: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") + if ema_state is not None: + log0("ema: applying EMA weights") + avg_sd = {n: t.to(dtype=base_model.state_dict()[n].dtype) for n, t in ema_state.items()} + base_model.load_state_dict(avg_sd, strict=True); del ema_state, avg_sd + elif swa_state is not None: + log0(f"swa: averaging {swa_count} checkpoints") + base_model.load_state_dict(swa_state, strict=True); del swa_state + + export_sd = base_model.state_dict() + if master_process: + torch.save(export_sd, "final_model.pt") + log0(f"Serialized model: {os.path.getsize('final_model.pt')} bytes Code: {len(code.encode('utf-8'))} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + quant_result, quant_meta = mixed_quantize(sd_cpu, MIXED_QUANT_INT6_CATS) + qbuf = io.BytesIO(); torch.save({"w": quant_result, "m": quant_meta}, qbuf); qraw = qbuf.getvalue() + if HAS_ZSTD: qblob = zstd.ZstdCompressor(level=22).compress(qraw); cl = "zstd-22" + else: qblob = zlib.compress(qraw, level=9); cl = "zlib-9" + if master_process: + with open("final_model.int8.ptz", "wb") as f: f.write(qblob) + qfb = len(qblob); cb = len(code.encode("utf-8")) + log0(f"final_int8_zlib_roundtrip compressed_model_bytes:{qfb} code_bytes:{cb} total_artifact_bytes:{qfb + cb}") + log0(f"Serialized {cl}: {qfb} bytes Total: {qfb + cb} bytes") + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: qbd = f.read() + rd = zstd.ZstdDecompressor().decompress(qbd) if HAS_ZSTD else zlib.decompress(qbd) + qs = torch.load(io.BytesIO(rd), map_location="cpu") + deq_sd = dequantize_mixed(qs["w"], qs["m"], sd_cpu) + eval_model = GPT(vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale).to(device).bfloat16() + for module in eval_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_sd, strict=True) + + if args.ttt_enabled: + log0(f"ttt: p1={args.ttt_p1_epochs}ep p2={args.ttt_epochs}ep freeze={args.ttt_freeze_blocks}") + torch.cuda.synchronize(); t_ttt = time.perf_counter() + ttt_adapt(args, eval_model, rank, world_size, device, val_tokens, log0) + torch.cuda.synchronize(); log0(f"ttt: total {1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + + torch.cuda.synchronize(); tqe = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + qvl, qvb = eval_val_sliding(args, eval_model, rank, world_size, device, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, stride=args.eval_stride, batch_seqs=args.eval_batch_seqs) + else: + qvl, qvb = eval_val(args, eval_model, rank, world_size, device, grad_accum_steps, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"final_roundtrip val_loss:{qvl:.4f} val_bpb:{qvb:.4f} eval_time:{1000.0 * (time.perf_counter() - tqe):.0f}ms") + log0(f"final_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvb:.8f}") + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Mon Mar 23 02:13:23 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.95.05 Driver Version: 580.95.05 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 32C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 33C P0 123W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 30C P0 112W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 47C P0 126W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 32C P0 122W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 33C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 31C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 1 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 2 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 3 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 4 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 5 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 6 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 7 N/A N/A 1 C /bin/dumb-init 1512MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:tokens:62021632 +model_params:27092057 swa:False compression:zstd-22 +bigram_vocab:4096 bigram_dim:128 grad_clip:0.3 muon_wd:0.04 +train_batch_tokens:524288 train_seq_len:2048 warmdown:3000 seed:42 +ste_qat:False ste_range:31 int6_cats:frozenset({'mlp', 'tok_emb', 'attn'}) +fp16_passthrough:() +xsa_last_n:0 ema:True ema_decay:0.997 rope_dims:16 ln_scale:True +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9320 val_bpb:4.1055 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9334 train_time:224ms step_avg:224.14ms +step:2/20000 train_loss:8.8138 train_time:292ms step_avg:146.01ms +step:3/20000 train_loss:7.9651 train_time:370ms step_avg:123.32ms +step:4/20000 train_loss:7.2899 train_time:440ms step_avg:110.05ms +step:5/20000 train_loss:6.9610 train_time:510ms step_avg:101.98ms +step:6/20000 train_loss:7.7778 train_time:580ms step_avg:96.71ms +step:7/20000 train_loss:6.7681 train_time:651ms step_avg:92.93ms +step:8/20000 train_loss:6.6433 train_time:722ms step_avg:90.21ms +step:9/20000 train_loss:6.4714 train_time:794ms step_avg:88.26ms +step:10/20000 train_loss:6.2261 train_time:863ms step_avg:86.32ms +step:100/20000 train_loss:3.3053 train_time:7229ms step_avg:72.29ms +step:200/20000 train_loss:2.7926 train_time:14389ms step_avg:71.94ms +step:300/20000 train_loss:2.4213 train_time:21439ms step_avg:71.46ms +step:400/20000 train_loss:2.2810 train_time:28602ms step_avg:71.51ms +step:500/20000 train_loss:2.4250 train_time:35604ms step_avg:71.21ms +step:500/20000 val_loss:2.4247 val_bpb:1.4360 train_time:35609ms step_avg:71.22ms +step:600/20000 train_loss:2.4825 train_time:42776ms step_avg:71.29ms +step:700/20000 train_loss:2.3779 train_time:49826ms step_avg:71.18ms +step:800/20000 train_loss:2.2242 train_time:56988ms step_avg:71.23ms +step:900/20000 train_loss:2.2840 train_time:64004ms step_avg:71.12ms +step:1000/20000 train_loss:2.3249 train_time:71186ms step_avg:71.19ms +step:1000/20000 val_loss:2.2774 val_bpb:1.3488 train_time:71192ms step_avg:71.19ms +step:1100/20000 train_loss:2.2035 train_time:78347ms step_avg:71.22ms +step:1200/20000 train_loss:2.3489 train_time:85551ms step_avg:71.29ms +step:1300/20000 train_loss:2.3228 train_time:92576ms step_avg:71.21ms +step:1400/20000 train_loss:2.3818 train_time:99777ms step_avg:71.27ms +step:1500/20000 train_loss:2.1873 train_time:106860ms step_avg:71.24ms +step:1500/20000 val_loss:2.2287 val_bpb:1.3200 train_time:106867ms step_avg:71.24ms +step:1600/20000 train_loss:2.0546 train_time:114020ms step_avg:71.26ms +step:1700/20000 train_loss:2.1174 train_time:121049ms step_avg:71.21ms +step:1800/20000 train_loss:2.1536 train_time:128192ms step_avg:71.22ms +step:1900/20000 train_loss:2.1455 train_time:135229ms step_avg:71.17ms +step:2000/20000 train_loss:2.1899 train_time:142415ms step_avg:71.21ms +step:2000/20000 val_loss:2.1726 val_bpb:1.2868 train_time:142421ms step_avg:71.21ms +step:2100/20000 train_loss:2.2101 train_time:149610ms step_avg:71.24ms +step:2200/20000 train_loss:2.0134 train_time:156668ms step_avg:71.21ms +step:2300/20000 train_loss:2.3090 train_time:163845ms step_avg:71.24ms +step:2400/20000 train_loss:2.1328 train_time:170886ms step_avg:71.20ms +step:2500/20000 train_loss:2.0697 train_time:178042ms step_avg:71.22ms +step:2500/20000 val_loss:2.1436 val_bpb:1.2695 train_time:178046ms step_avg:71.22ms +step:2600/20000 train_loss:2.3660 train_time:185098ms step_avg:71.19ms +step:2700/20000 train_loss:2.0864 train_time:192248ms step_avg:71.20ms +step:2800/20000 train_loss:2.1731 train_time:199295ms step_avg:71.18ms +step:2900/20000 train_loss:2.1187 train_time:206468ms step_avg:71.20ms +step:3000/20000 train_loss:2.1616 train_time:213512ms step_avg:71.17ms +step:3000/20000 val_loss:2.1279 val_bpb:1.2602 train_time:213519ms step_avg:71.17ms +step:3100/20000 train_loss:2.1327 train_time:220701ms step_avg:71.19ms +step:3200/20000 train_loss:2.1290 train_time:227724ms step_avg:71.16ms +step:3300/20000 train_loss:2.1716 train_time:234878ms step_avg:71.18ms +step:3400/20000 train_loss:2.0996 train_time:241905ms step_avg:71.15ms +step:3500/20000 train_loss:2.1931 train_time:249032ms step_avg:71.15ms +step:3500/20000 val_loss:2.1180 val_bpb:1.2544 train_time:249036ms step_avg:71.15ms +step:3600/20000 train_loss:2.0394 train_time:256048ms step_avg:71.12ms +step:3700/20000 train_loss:2.0784 train_time:263183ms step_avg:71.13ms +step:3800/20000 train_loss:2.1495 train_time:270208ms step_avg:71.11ms +step:3900/20000 train_loss:1.9337 train_time:277367ms step_avg:71.12ms +step:4000/20000 train_loss:2.1223 train_time:284402ms step_avg:71.10ms +step:4000/20000 val_loss:2.1085 val_bpb:1.2488 train_time:284407ms step_avg:71.10ms +step:4100/20000 train_loss:2.1317 train_time:291566ms step_avg:71.11ms +step:4200/20000 train_loss:2.1145 train_time:298708ms step_avg:71.12ms +step:4300/20000 train_loss:1.9561 train_time:305759ms step_avg:71.11ms +step:4400/20000 train_loss:2.0522 train_time:312937ms step_avg:71.12ms +step:4500/20000 train_loss:2.2016 train_time:319974ms step_avg:71.11ms +step:4500/20000 val_loss:2.1042 val_bpb:1.2462 train_time:319979ms step_avg:71.11ms +step:4600/20000 train_loss:1.9152 train_time:327139ms step_avg:71.12ms +step:4700/20000 train_loss:2.2166 train_time:334170ms step_avg:71.10ms +step:4800/20000 train_loss:2.2026 train_time:341316ms step_avg:71.11ms +step:4900/20000 train_loss:2.1143 train_time:348337ms step_avg:71.09ms +step:5000/20000 train_loss:1.9593 train_time:355539ms step_avg:71.11ms +step:5000/20000 val_loss:2.0978 val_bpb:1.2425 train_time:355543ms step_avg:71.11ms +step:5100/20000 train_loss:1.9749 train_time:362579ms step_avg:71.09ms +step:5200/20000 train_loss:2.1228 train_time:369700ms step_avg:71.10ms +step:5300/20000 train_loss:2.1534 train_time:376725ms step_avg:71.08ms +step:5400/20000 train_loss:2.1396 train_time:383900ms step_avg:71.09ms +step:5500/20000 train_loss:2.0889 train_time:390929ms step_avg:71.08ms +step:5500/20000 val_loss:2.0950 val_bpb:1.2408 train_time:390936ms step_avg:71.08ms +step:5600/20000 train_loss:2.1266 train_time:398107ms step_avg:71.09ms +step:5700/20000 train_loss:2.1149 train_time:405131ms step_avg:71.08ms +step:5800/20000 train_loss:2.0709 train_time:412298ms step_avg:71.09ms +step:5900/20000 train_loss:2.0259 train_time:419311ms step_avg:71.07ms +step:6000/20000 train_loss:2.1469 train_time:426434ms step_avg:71.07ms +step:6000/20000 val_loss:2.0752 val_bpb:1.2291 train_time:426441ms step_avg:71.07ms +step:6100/20000 train_loss:2.0498 train_time:433461ms step_avg:71.06ms +step:6200/20000 train_loss:2.0181 train_time:440588ms step_avg:71.06ms +step:6300/20000 train_loss:1.9535 train_time:447754ms step_avg:71.07ms +step:6400/20000 train_loss:2.0879 train_time:454780ms step_avg:71.06ms +step:6500/20000 train_loss:1.9986 train_time:461927ms step_avg:71.07ms +step:6500/20000 val_loss:2.0524 val_bpb:1.2155 train_time:461932ms step_avg:71.07ms +step:6600/20000 train_loss:2.0320 train_time:468945ms step_avg:71.05ms +step:6700/20000 train_loss:2.0699 train_time:476080ms step_avg:71.06ms +step:6800/20000 train_loss:2.0839 train_time:483121ms step_avg:71.05ms +step:6900/20000 train_loss:2.0006 train_time:490237ms step_avg:71.05ms +step:7000/20000 train_loss:2.1253 train_time:497284ms step_avg:71.04ms +step:7000/20000 val_loss:2.0272 val_bpb:1.2006 train_time:497290ms step_avg:71.04ms +step:7100/20000 train_loss:1.9569 train_time:504475ms step_avg:71.05ms +step:7200/20000 train_loss:2.0891 train_time:511500ms step_avg:71.04ms +step:7300/20000 train_loss:1.9770 train_time:518644ms step_avg:71.05ms +step:7400/20000 train_loss:2.0012 train_time:525684ms step_avg:71.04ms +step:7500/20000 train_loss:1.9895 train_time:532864ms step_avg:71.05ms +step:7500/20000 val_loss:1.9991 val_bpb:1.1840 train_time:532869ms step_avg:71.05ms +step:7600/20000 train_loss:1.8682 train_time:539901ms step_avg:71.04ms +step:7700/20000 train_loss:1.9432 train_time:547070ms step_avg:71.05ms +step:7800/20000 train_loss:2.0085 train_time:554074ms step_avg:71.04ms +step:7900/20000 train_loss:1.9855 train_time:561249ms step_avg:71.04ms +step:8000/20000 train_loss:1.9670 train_time:568289ms step_avg:71.04ms +step:8000/20000 val_loss:1.9672 val_bpb:1.1651 train_time:568295ms step_avg:71.04ms +step:8100/20000 train_loss:1.9936 train_time:575470ms step_avg:71.05ms +step:8200/20000 train_loss:2.0267 train_time:582507ms step_avg:71.04ms +step:8300/20000 train_loss:1.9434 train_time:589659ms step_avg:71.04ms +step:8400/20000 train_loss:1.9615 train_time:596816ms step_avg:71.05ms +step:8443/20000 val_loss:1.9431 val_bpb:1.1508 train_time:599822ms step_avg:71.04ms +stopping_early: wallclock_cap train_time:599822ms step:8443/20000 +peak memory: 14059 MiB +ema: applying EMA weights +Serialized model: 106313663 bytes Code: 49961 bytes +final_int8_zlib_roundtrip compressed_model_bytes:15756295 code_bytes:49961 total_artifact_bytes:15806256 +Serialized zstd-22: 15756295 bytes Total: 15806256 bytes +ttt: p1=0ep p2=25ep freeze=0 +ttt_p2: 25ep SGD lr=0.012 freeze=0 params=25974872 +ttt_p2: done in 387926ms +ttt: total 387930ms +final_eval_mode:sliding_window stride:64 batch_seqs:32 +final_roundtrip val_loss:1.8984 val_bpb:1.1243 eval_time:186795ms +final_roundtrip_exact val_loss:1.89835246 val_bpb:1.12431423 diff --git a/records/track_10min_16mb/2026-03-21_DominationV3/train_seed7.log b/records/track_10min_16mb/2026-03-21_DominationV3/train_seed7.log new file mode 100644 index 0000000000..a58250de76 --- /dev/null +++ b/records/track_10min_16mb/2026-03-21_DominationV3/train_seed7.log @@ -0,0 +1,1101 @@ +"""Domination V3: compact no-TTT path for 10-min track.""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.5)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.01)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 25)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_p1_epochs = int(os.environ.get("TTT_P1_EPOCHS", 100)) + ttt_p1_lr = float(os.environ.get("TTT_P1_LR", 0.01)) + rope_dims = int(os.environ.get("ROPE_DIMS", 0)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "0"))) + +# ----------------------------- +# MUON OPTIMIZER WITH WEIGHT DECAY +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay)) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum = group["lr"], group["momentum"] + backend_steps, nesterov = group["backend_steps"], group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts(sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): base_bytes_np[token_id] = 1; continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): has_leading_space_np[token_id] = True; piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return (torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: raise ValueError(f"Validation split is too short for seq_len={seq_len}") + return tokens[: usable + 1] + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError(f"VAL_BATCH_SIZE too small for world={world_size} accum={grad_accum_steps} seq={args.train_seq_len}") + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + vls = torch.zeros((), device=device, dtype=torch.float64) + vtc = torch.zeros((), device=device, dtype=torch.float64) + vbc = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for bs in range(seq_start, seq_end, local_batch_seqs): + be = min(bs + local_batch_seqs, seq_end) + rs, re = bs * args.train_seq_len, be * args.train_seq_len + 1 + local = val_tokens[rs:re].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + bl = model(x, y).detach() + bc = float(y.numel()) + vls += bl.to(torch.float64) * bc; vtc += bc + pi, ti = x.reshape(-1), y.reshape(-1) + tb = base_bytes_lut[ti].to(dtype=torch.int16) + tb += (has_leading_space_lut[ti] & ~is_boundary_token_lut[pi]).to(dtype=torch.int16) + vbc += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(vls, op=dist.ReduceOp.SUM); dist.all_reduce(vtc, op=dist.ReduceOp.SUM); dist.all_reduce(vbc, op=dist.ReduceOp.SUM) + vl = vls / vtc; bpt = vl.item() / math.log(2.0); tpb = vtc.item() / vbc.item() + model.train(); return float(vl.item()), float(bpt * tpb) + +def eval_val_sliding(args, base_model, rank, world_size, device, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride, batch_seqs=32): + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) if min(ws + seq_len, total_tokens) - ws >= stride] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + ls = torch.zeros((), device=device, dtype=torch.float64) + tc = torch.zeros((), device=device, dtype=torch.float64) + bc = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + for bi in range(0, len(my_windows), batch_seqs): + bw = my_windows[bi:bi + batch_seqs] + bsz = len(bw) + xb = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + yb = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(bw): + ep = min(ws + seq_len, total_tokens) + wl = ep - ws + wlens.append(wl) + ch = val_tokens[ws:ep + 1].to(dtype=torch.int64, device=device) + xb[i, :wl] = ch[:-1] + yb[i, :wl] = ch[1:] + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(xb) + nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), yb.reshape(-1), reduction="none").reshape(bsz, seq_len) + for i, ws in enumerate(bw): + wl = wlens[i] + s = 0 if ws == 0 else wl - stride + sn = nll[i, s:wl].to(torch.float64) + ls += sn.sum() + tc += float(wl - s) + tgt, prev = yb[i, s:wl], xb[i, s:wl] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + bc += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + rb = 0.0 + if tc.item() > 0: + rl = (ls / tc).item() + rb = rl / math.log(2.0) * (tc.item() / bc.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={rb:.6f}", flush=True) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(ls, op=dist.ReduceOp.SUM); dist.all_reduce(tc, op=dist.ReduceOp.SUM); dist.all_reduce(bc, op=dist.ReduceOp.SUM) + vl = (ls / tc).item(); bpt = vl / math.log(2.0); tpb = tc.item() / bc.item() + base_model.train(); return vl, bpt * tpb + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get("CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale").split(",") if p) + +INT6_QUANT_RANGE = 31 +MIXED_QUANT_INT6_CATS = frozenset( + c.strip() for c in os.environ.get("MIXED_QUANT_INT6_CATS", "mlp,attn").split(",") if c.strip() +) +STE_QAT_ENABLED = bool(int(os.environ.get("STE_QAT_ENABLED", "0"))) +STE_QAT_RANGE = int(os.environ.get("STE_QAT_RANGE", INT6_QUANT_RANGE)) +FP16_PASSTHROUGH_PATTERNS = tuple( + p.strip() for p in os.environ.get("FP16_PASSTHROUGH_PATTERNS", "").split(",") if p.strip() +) + +def _classify_param(name): + if "tok_emb" in name or "lm_head" in name: return "embed" + if ".mlp." in name: return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): return "attn" + return "other" + +def _get_ste_range_for_param(name): + return STE_QAT_RANGE + +CLIP_PCTS = [0.999, 0.9995, 0.9999, 0.99999, 1.0] + +def quantize_int6_per_row(t): + t32 = t.float() + if t32.ndim == 2: + best_q, best_sc, best_mse = None, None, float("inf") + for pct in CLIP_PCTS: + rm = torch.quantile(t32.abs(), pct, dim=1) + sc = (rm / 31.0).clamp_min(1.0 / 31.0).to(torch.float16) + tc = torch.clamp(t32, -rm[:, None], rm[:, None]) + q = torch.clamp(torch.round(tc / sc.float()[:, None]), -32, 31).to(torch.int8) + recon = q.float() * sc.float()[:, None] + mse = (t32 - recon).square().mean().item() + if mse < best_mse: best_mse = mse; best_q = q; best_sc = sc + return best_q, best_sc + am = t32.abs().max().item() + sc = torch.tensor(am / 31.0 if am > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / sc.float()), -32, 31).to(torch.int8) + return q, sc + +def quantize_int8_per_row(t): + t32 = t.float() + if t32.ndim == 2: + rm = t32.abs().amax(dim=1) + sc = (rm / 127.0).clamp_min(1e-8).to(torch.float16) + q = torch.clamp(torch.round(t32 / sc.float()[:, None]), -127, 127).to(torch.int8) + return q, sc + am = t32.abs().max().item() + sc = torch.tensor(am / 127.0 if am > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / sc.float()), -127, 127).to(torch.int8) + return q, sc + +def mixed_quantize(state_dict, int6_cats): + result, meta = {}, {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if t.is_floating_point() and FP16_PASSTHROUGH_PATTERNS and any(p in name for p in FP16_PASSTHROUGH_PATTERNS): + result[name] = t.to(torch.float16) + meta[name] = "passthrough_fp16" + continue + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q; result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_int8_per_row(t) + result[name + ".q"] = q; result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed(result, meta, template_sd): + out = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: continue + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig.dtype in (torch.float32, torch.bfloat16): + t = t.to(orig.dtype) + out[name] = t; continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig.dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig.dtype) + return out + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file): + hb = 256 * np.dtype(" 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._advance_file(); continue + k = min(remaining, avail); chunks.append(self.tokens[self.pos:self.pos + k]); self.pos += k; remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern, rank, world_size, device): + self.rank, self.world_size, self.device, self.stream = rank, world_size, device, TokenStream(pattern) + def next_batch(self, global_tokens, seq_len, grad_accum_steps): + lt = global_tokens // (self.world_size * grad_accum_steps); prs = lt + 1 + chunk = self.stream.take(prs * self.world_size); s = self.rank * prs + local = chunk[s:s + prs].to(dtype=torch.int64) + x, y = local[:-1].reshape(-1, seq_len), local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps=None): super().__init__(); self.eps = eps + def forward(self, x): return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +class CastedLinear(nn.Linear): + def forward(self, x): + w = self.weight.to(x.dtype) + if self.training and STE_QAT_ENABLED and w.ndim == 2: + with torch.no_grad(): + w32 = w.float(); rm = w32.abs().amax(dim=1).clamp_min(1e-8) + sc = rm / 31.0; wc = torch.clamp(w32, -rm[:, None], rm[:, None]) + wq = (torch.round(wc / sc[:, None]) * sc[:, None]).to(x.dtype) + w = w + (wq - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + +def restore_low_dim_params_to_fp32(module): + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + +class Rotary(nn.Module): + def __init__(self, dim, base=10000.0, train_seq_len=1024, rope_dims=0): + super().__init__() + self._base = base; self._train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + rd = self.rope_dims + self.register_buffer("inv_freq", 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)), persistent=False) + self._seq_len_cached = 0; self._cos_cached = None; self._sin_cached = None + def forward(self, seq_len, device, dtype): + if self._cos_cached is None or self._seq_len_cached != seq_len or self._cos_cached.device != device: + rd = self.rope_dims + if seq_len > self._train_seq_len: + scale = seq_len / self._train_seq_len + new_base = self._base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, None, :, :]; self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +def apply_rotary_emb(x, cos, sin, rope_dims=0): + if rope_dims > 0 and rope_dims < x.size(-1): + xr, xp = x[..., :rope_dims], x[..., rope_dims:] + h = rope_dims // 2; x1, x2 = xr[..., :h], xr[..., h:] + xr = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((xr, xp), dim=-1) + h = x.size(-1) // 2; x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=0): + super().__init__() + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads; self.head_dim = dim // num_heads + kv_dim = num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False); self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False); self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = rope_dims + self.rotary = Rotary(self.head_dim, base=rope_base, rope_dims=rope_dims) + self.use_xsa = False + + def _xsa_efficient(self, y, v): + B, T, H, D = y.shape + Hkv = v.size(2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x): + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True, enable_gqa=(self.num_kv_heads != self.num_heads)) + y = y.transpose(1, 2).contiguous() + if self.use_xsa: + v_bthd = v.transpose(1, 2).contiguous() + y = self._xsa_efficient(y, v_bthd) + return self.proj(y.reshape(bsz, seqlen, dim)) + +class MLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__(); hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False); self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x): return self.proj(torch.relu(self.fc(x)).square()) + +class SmearGate(nn.Module): + """Per-dimension SmearGate (from PR #194): each dim has its own blend ratio.""" + def __init__(self, dim): + super().__init__(); self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x): + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size, bigram_dim, model_dim): + super().__init__(); self.bvs = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim); nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def forward(self, token_ids): + t = token_ids.to(torch.int32); mod = self.bvs - 1 + out = torch.empty_like(t); out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + h = self.embed(out.long()) + if self.proj is not None: h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class Block(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, rope_dims=0, ln_sf=1.0): + super().__init__() + self.attn_norm, self.mlp_norm = RMSNorm(), RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_sf = ln_sf + def forward(self, x, x0): + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_sf + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x) * s) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) + return x + +class GPT(nn.Module): + def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, mlp_mult, + tie_embeddings, tied_embed_init_std, logit_softcap, rope_base, qk_gain_init, + bigram_vocab_size=0, bigram_dim=128, xsa_last_n=0, rope_dims=0, ln_scale=False): + super().__init__() + self.tie_embeddings, self.tied_embed_init_std, self.logit_softcap = tie_embeddings, tied_embed_init_std, logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, rope_dims=rope_dims, ln_sf=1.0 / math.sqrt(i + 1) if ln_scale else 1.0) for i in range(num_layers)]) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: self.lm_head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self): + if self.tie_embeddings: nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + nl = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): module.weight.mul_(1.0 / math.sqrt(2 * nl)) + + def _run_blocks(self, x, x0): + skips = [] + for i in range(self.num_encoder_layers): x = self.blocks[i](x, x0); skips.append(x) + for i in range(self.num_decoder_layers): + if skips: x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + return x + + def _embed(self, input_ids): + x = self.tok_emb(input_ids) + if self.bigram is not None: x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + return self.smear(x) + + def _logits(self, x): + if self.tie_embeddings: lp = F.linear(x, self.tok_emb.weight) + else: lp = self.lm_head(x) + return self.logit_softcap * torch.tanh(lp / self.logit_softcap) + + def forward(self, input_ids, target_ids): + x0 = self._embed(input_ids) + x = self.final_norm(self._run_blocks(x0, x0)) + x_flat = x.reshape(-1, x.size(-1)) + return F.cross_entropy(self._logits(x_flat).float(), target_ids.reshape(-1), reduction="mean") + + def forward_logits(self, input_ids): + x0 = self._embed(input_ids) + return self._logits(self.final_norm(self._run_blocks(x0, x0))) + +# ----------------------------- +# TRAINING +# ----------------------------- + +def _ttt_run(mdl, opt, epochs, rank, world_size, device, val_tokens, sl, batch_seqs): + nt = val_tokens.numel() - 1; ts = nt // sl + ms, me = (ts * rank) // world_size, (ts * (rank + 1)) // world_size + mdl.train() + for _ in range(epochs): + for bs in range(ms, me, batch_seqs): + be = min(bs + batch_seqs, me) + rs, re = bs * sl, be * sl + 1 + loc = val_tokens[rs:re].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = loc[:-1].reshape(-1, sl), loc[1:].reshape(-1, sl) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): loss = mdl(x, y) + opt.zero_grad(set_to_none=True); loss.backward(); opt.step() + if dist.is_available() and dist.is_initialized(): dist.barrier() + +def ttt_adapt(args, mdl, rank, world_size, device, val_tokens, log0): + sl = args.train_seq_len + for p in mdl.parameters(): p.requires_grad_(False) + if args.ttt_p1_epochs > 0: + norm_params = [] + for n, p in mdl.named_parameters(): + if "norm" in n or "attn_scale" in n or "mlp_scale" in n or "resid_mix" in n or "skip_weight" in n: + p.requires_grad_(True); norm_params.append(p) + log0(f"ttt_p1: {args.ttt_p1_epochs}ep Adam lr={args.ttt_p1_lr} params={sum(p.numel() for p in norm_params)}") + opt1 = torch.optim.Adam(norm_params, lr=args.ttt_p1_lr) + torch.cuda.synchronize(); t0 = time.perf_counter() + _ttt_run(mdl, opt1, args.ttt_p1_epochs, rank, world_size, device, val_tokens, sl, args.ttt_batch_seqs) + torch.cuda.synchronize(); log0(f"ttt_p1: done in {1000.0 * (time.perf_counter() - t0):.0f}ms") + for p in norm_params: p.requires_grad_(False) + del opt1 + if args.ttt_epochs > 0: + nl = len(mdl.blocks) + for i, b in enumerate(mdl.blocks): + req = i >= args.ttt_freeze_blocks + for p in b.parameters(): p.requires_grad_(req) + trainable = [p for p in mdl.parameters() if p.requires_grad] + log0(f"ttt_p2: {args.ttt_epochs}ep SGD lr={args.ttt_lr} freeze={args.ttt_freeze_blocks} params={sum(p.numel() for p in trainable)}") + opt2 = torch.optim.SGD(trainable, lr=args.ttt_lr, momentum=args.ttt_momentum) + torch.cuda.synchronize(); t0 = time.perf_counter() + _ttt_run(mdl, opt2, args.ttt_epochs, rank, world_size, device, val_tokens, sl, args.ttt_batch_seqs) + torch.cuda.synchronize(); log0(f"ttt_p2: done in {1000.0 * (time.perf_counter() - t0):.0f}ms") + del opt2 + for p in mdl.parameters(): p.requires_grad_(False) + mdl.eval() + +def main(): + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8"); args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank, world_size = int(os.environ.get("RANK", "0")), int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + grad_accum_steps = 8 // world_size; grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank); torch.cuda.set_device(device) + if distributed: dist.init_process_group(backend="nccl", device_id=device); dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True; torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False); enable_flash_sdp(True); enable_mem_efficient_sdp(False); enable_math_sdp(False) + logfile = None + if master_process: os.makedirs("logs", exist_ok=True); logfile = f"logs/{args.run_id}.txt"; print(logfile) + def log0(msg, console=True): + if not master_process: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + log0(code, console=False); log0("=" * 100, console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + log0("=" * 100, console=False) + random.seed(args.seed); np.random.seed(args.seed); torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:tokens:{val_tokens.numel() - 1}") + + base_model = GPT(vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [p for n, p in block_named_params if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = [p for n, p in block_named_params if p.ndim < 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + if base_model.skip_weights.numel() > 0: scalar_params.append(base_model.skip_weights) + for p in base_model.smear.parameters(): scalar_params.append(p) + if base_model.bigram is not None: + for n, p in base_model.bigram.named_parameters(): + if p.ndim == 2 and p.shape[0] >= 64 and p.shape[1] >= 64: matrix_params.append(p) + else: scalar_params.append(p) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.AdamW([{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.weight_decay, fused=True) + optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, weight_decay=args.muon_wd) + for group in optimizer_muon.param_groups: group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW([{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.weight_decay, fused=True) + optimizers = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizers.insert(1, torch.optim.AdamW([{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.weight_decay, fused=True)) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} swa:{args.swa_enabled} compression:{'zstd-22' if HAS_ZSTD else 'zlib-9'}") + log0(f"bigram_vocab:{args.bigram_vocab_size} bigram_dim:{args.bigram_dim} grad_clip:{args.grad_clip_norm} muon_wd:{args.muon_wd}") + log0(f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} warmdown:{args.warmdown_iters} seed:{args.seed}") + log0(f"ste_qat:{STE_QAT_ENABLED} ste_range:{STE_QAT_RANGE} int6_cats:{MIXED_QUANT_INT6_CATS}") + log0(f"fp16_passthrough:{FP16_PASSTHROUGH_PATTERNS}") + log0(f"xsa_last_n:{args.xsa_last_n} ema:{args.ema_enabled} ema_decay:{args.ema_decay} rope_dims:{args.rope_dims} ln_scale:{args.ln_scale}") + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all(): + for opt in optimizers: opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step, elapsed_ms): + if args.warmdown_iters <= 0: return 1.0 + if max_wallclock_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1); wms = args.warmdown_iters * step_ms + rms = max(max_wallclock_ms - elapsed_ms, 0.0) + return rms / max(wms, 1e-9) if rms <= wms else 1.0 + + if args.warmup_steps > 0: + init_sd = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers]; model.train() + for ws in range(args.warmup_steps): + zero_grad_all() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = ms == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): wl = model(x, y) + (wl * grad_scale).backward() + for o in optimizers: o.step() + zero_grad_all() + if args.warmup_steps <= 20 or (ws + 1) % 10 == 0: log0(f"warmup_step:{ws + 1}/{args.warmup_steps}") + base_model.load_state_dict(init_sd, strict=True) + for o, s in zip(optimizers, init_opts, strict=True): o.load_state_dict(s) + zero_grad_all() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + training_time_ms = 0.0; stop_after_step = None; swa_state = None; swa_count = 0 + ema_state = None + if args.ema_enabled: + ema_state = {n: t.detach().float().clone() for n, t in base_model.state_dict().items()} + torch.cuda.synchronize(); t0 = time.perf_counter(); step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + torch.cuda.synchronize(); training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vb = eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vb:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0); scale = lr_mul(step, elapsed_ms) + if args.swa_enabled and not args.ema_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + current = {k: v.detach().cpu().clone() for k, v in base_model.state_dict().items()} + if swa_state is None: + swa_state = current; swa_count = 1 + else: + inv = 1.0 / (swa_count + 1); keep = 1.0 - inv + for k, t in current.items(): + if torch.is_floating_point(swa_state[k]): swa_state[k].mul_(keep).add_(t, alpha=inv) + else: swa_state[k] = t + swa_count += 1 + zero_grad_all(); train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = ms == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): loss = model(x, y) + train_loss += loss.detach(); (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + for group in optimizer_muon.param_groups: group["momentum"] = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for opt in optimizers: + for group in opt.param_groups: group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: opt.step() + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for n, t in base_model.state_dict().items(): + ema_state[n].mul_(d).add_(t.detach().float(), alpha=1.0 - d) + zero_grad_all(); step += 1 + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None): + log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} train_time:{approx_ms:.0f}ms step_avg:{approx_ms / step:.2f}ms") + reached_cap = max_wallclock_ms is not None and approx_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + rct = torch.tensor(int(reached_cap), device=device); dist.all_reduce(rct, op=dist.ReduceOp.MAX); reached_cap = bool(rct.item()) + if stop_after_step is None and reached_cap: stop_after_step = step + + log0(f"peak memory: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") + if ema_state is not None: + log0("ema: applying EMA weights") + avg_sd = {n: t.to(dtype=base_model.state_dict()[n].dtype) for n, t in ema_state.items()} + base_model.load_state_dict(avg_sd, strict=True); del ema_state, avg_sd + elif swa_state is not None: + log0(f"swa: averaging {swa_count} checkpoints") + base_model.load_state_dict(swa_state, strict=True); del swa_state + + export_sd = base_model.state_dict() + if master_process: + torch.save(export_sd, "final_model.pt") + log0(f"Serialized model: {os.path.getsize('final_model.pt')} bytes Code: {len(code.encode('utf-8'))} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + quant_result, quant_meta = mixed_quantize(sd_cpu, MIXED_QUANT_INT6_CATS) + qbuf = io.BytesIO(); torch.save({"w": quant_result, "m": quant_meta}, qbuf); qraw = qbuf.getvalue() + if HAS_ZSTD: qblob = zstd.ZstdCompressor(level=22).compress(qraw); cl = "zstd-22" + else: qblob = zlib.compress(qraw, level=9); cl = "zlib-9" + if master_process: + with open("final_model.int8.ptz", "wb") as f: f.write(qblob) + qfb = len(qblob); cb = len(code.encode("utf-8")) + log0(f"final_int8_zlib_roundtrip compressed_model_bytes:{qfb} code_bytes:{cb} total_artifact_bytes:{qfb + cb}") + log0(f"Serialized {cl}: {qfb} bytes Total: {qfb + cb} bytes") + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: qbd = f.read() + rd = zstd.ZstdDecompressor().decompress(qbd) if HAS_ZSTD else zlib.decompress(qbd) + qs = torch.load(io.BytesIO(rd), map_location="cpu") + deq_sd = dequantize_mixed(qs["w"], qs["m"], sd_cpu) + eval_model = GPT(vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale).to(device).bfloat16() + for module in eval_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_sd, strict=True) + + if args.ttt_enabled: + log0(f"ttt: p1={args.ttt_p1_epochs}ep p2={args.ttt_epochs}ep freeze={args.ttt_freeze_blocks}") + torch.cuda.synchronize(); t_ttt = time.perf_counter() + ttt_adapt(args, eval_model, rank, world_size, device, val_tokens, log0) + torch.cuda.synchronize(); log0(f"ttt: total {1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + + torch.cuda.synchronize(); tqe = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + qvl, qvb = eval_val_sliding(args, eval_model, rank, world_size, device, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, stride=args.eval_stride, batch_seqs=args.eval_batch_seqs) + else: + qvl, qvb = eval_val(args, eval_model, rank, world_size, device, grad_accum_steps, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"final_roundtrip val_loss:{qvl:.4f} val_bpb:{qvb:.4f} eval_time:{1000.0 * (time.perf_counter() - tqe):.0f}ms") + log0(f"final_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvb:.8f}") + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Mon Mar 23 01:47:24 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.95.05 Driver Version: 580.95.05 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 30C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 | +| N/A 32C P0 122W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 | +| N/A 31C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 32C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 | +| N/A 31C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 | +| N/A 30C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 1 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 2 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 3 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 4 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 5 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 6 N/A N/A 1 C /bin/dumb-init 1512MiB | +| 7 N/A N/A 1 C /bin/dumb-init 1512MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:tokens:62021632 +model_params:27092057 swa:False compression:zstd-22 +bigram_vocab:4096 bigram_dim:128 grad_clip:0.3 muon_wd:0.04 +train_batch_tokens:524288 train_seq_len:2048 warmdown:3000 seed:7 +ste_qat:False ste_range:31 int6_cats:frozenset({'mlp', 'tok_emb', 'attn'}) +fp16_passthrough:() +xsa_last_n:0 ema:True ema_decay:0.997 rope_dims:16 ln_scale:True +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9316 val_bpb:4.1053 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9332 train_time:225ms step_avg:224.95ms +step:2/20000 train_loss:8.8477 train_time:292ms step_avg:146.17ms +step:3/20000 train_loss:8.0149 train_time:365ms step_avg:121.62ms +step:4/20000 train_loss:7.2042 train_time:436ms step_avg:109.03ms +step:5/20000 train_loss:6.8607 train_time:509ms step_avg:101.81ms +step:6/20000 train_loss:7.6653 train_time:581ms step_avg:96.83ms +step:7/20000 train_loss:6.6026 train_time:651ms step_avg:93.01ms +step:8/20000 train_loss:6.5640 train_time:720ms step_avg:90.02ms +step:9/20000 train_loss:6.4115 train_time:790ms step_avg:87.75ms +step:10/20000 train_loss:6.1544 train_time:861ms step_avg:86.08ms +step:100/20000 train_loss:3.3140 train_time:7263ms step_avg:72.63ms +step:200/20000 train_loss:2.7892 train_time:14461ms step_avg:72.31ms +step:300/20000 train_loss:2.4113 train_time:21508ms step_avg:71.69ms +step:400/20000 train_loss:2.2793 train_time:28685ms step_avg:71.71ms +step:500/20000 train_loss:2.4253 train_time:35738ms step_avg:71.48ms +step:500/20000 val_loss:2.4242 val_bpb:1.4358 train_time:35744ms step_avg:71.49ms +step:600/20000 train_loss:2.4792 train_time:42899ms step_avg:71.50ms +step:700/20000 train_loss:2.3814 train_time:49961ms step_avg:71.37ms +step:800/20000 train_loss:2.2315 train_time:57146ms step_avg:71.43ms +step:900/20000 train_loss:2.2831 train_time:64218ms step_avg:71.35ms +step:1000/20000 train_loss:2.3323 train_time:71410ms step_avg:71.41ms +step:1000/20000 val_loss:2.2803 val_bpb:1.3505 train_time:71417ms step_avg:71.42ms +step:1100/20000 train_loss:2.1954 train_time:78560ms step_avg:71.42ms +step:1200/20000 train_loss:2.3496 train_time:85749ms step_avg:71.46ms +step:1300/20000 train_loss:2.3291 train_time:92822ms step_avg:71.40ms +step:1400/20000 train_loss:2.3777 train_time:100053ms step_avg:71.47ms +step:1500/20000 train_loss:2.1885 train_time:107106ms step_avg:71.40ms +step:1500/20000 val_loss:2.2297 val_bpb:1.3205 train_time:107112ms step_avg:71.41ms +step:1600/20000 train_loss:2.0482 train_time:114274ms step_avg:71.42ms +step:1700/20000 train_loss:2.1187 train_time:121318ms step_avg:71.36ms +step:1800/20000 train_loss:2.1545 train_time:128479ms step_avg:71.38ms +step:1900/20000 train_loss:2.1417 train_time:135532ms step_avg:71.33ms +step:2000/20000 train_loss:2.1914 train_time:142713ms step_avg:71.36ms +step:2000/20000 val_loss:2.1740 val_bpb:1.2876 train_time:142718ms step_avg:71.36ms +step:2100/20000 train_loss:2.2124 train_time:149866ms step_avg:71.36ms +step:2200/20000 train_loss:2.0082 train_time:156913ms step_avg:71.32ms +step:2300/20000 train_loss:2.3078 train_time:164117ms step_avg:71.36ms +step:2400/20000 train_loss:2.1308 train_time:171171ms step_avg:71.32ms +step:2500/20000 train_loss:2.0695 train_time:178345ms step_avg:71.34ms +step:2500/20000 val_loss:2.1455 val_bpb:1.2707 train_time:178349ms step_avg:71.34ms +step:2600/20000 train_loss:2.3702 train_time:185388ms step_avg:71.30ms +step:2700/20000 train_loss:2.0846 train_time:192570ms step_avg:71.32ms +step:2800/20000 train_loss:2.1748 train_time:199604ms step_avg:71.29ms +step:2900/20000 train_loss:2.1231 train_time:206784ms step_avg:71.30ms +step:3000/20000 train_loss:2.1647 train_time:213825ms step_avg:71.27ms +step:3000/20000 val_loss:2.1298 val_bpb:1.2614 train_time:213830ms step_avg:71.28ms +step:3100/20000 train_loss:2.1351 train_time:220966ms step_avg:71.28ms +step:3200/20000 train_loss:2.1280 train_time:228003ms step_avg:71.25ms +step:3300/20000 train_loss:2.1771 train_time:235166ms step_avg:71.26ms +step:3400/20000 train_loss:2.1030 train_time:242234ms step_avg:71.25ms +step:3500/20000 train_loss:2.1975 train_time:249504ms step_avg:71.29ms +step:3500/20000 val_loss:2.1202 val_bpb:1.2557 train_time:249509ms step_avg:71.29ms +step:3600/20000 train_loss:2.0442 train_time:256532ms step_avg:71.26ms +step:3700/20000 train_loss:2.0750 train_time:263719ms step_avg:71.28ms +step:3800/20000 train_loss:2.1444 train_time:270766ms step_avg:71.25ms +step:3900/20000 train_loss:1.9331 train_time:277919ms step_avg:71.26ms +step:4000/20000 train_loss:2.1227 train_time:284967ms step_avg:71.24ms +step:4000/20000 val_loss:2.1100 val_bpb:1.2496 train_time:284972ms step_avg:71.24ms +step:4100/20000 train_loss:2.1313 train_time:292165ms step_avg:71.26ms +step:4200/20000 train_loss:2.1134 train_time:299317ms step_avg:71.27ms +step:4300/20000 train_loss:1.9576 train_time:306377ms step_avg:71.25ms +step:4400/20000 train_loss:2.0534 train_time:313527ms step_avg:71.26ms +step:4500/20000 train_loss:2.2018 train_time:320577ms step_avg:71.24ms +step:4500/20000 val_loss:2.1059 val_bpb:1.2472 train_time:320582ms step_avg:71.24ms +step:4600/20000 train_loss:1.9130 train_time:327751ms step_avg:71.25ms +step:4700/20000 train_loss:2.2198 train_time:334790ms step_avg:71.23ms +step:4800/20000 train_loss:2.2014 train_time:341954ms step_avg:71.24ms +step:4900/20000 train_loss:2.1140 train_time:348996ms step_avg:71.22ms +step:5000/20000 train_loss:1.9615 train_time:356181ms step_avg:71.24ms +step:5000/20000 val_loss:2.1003 val_bpb:1.2439 train_time:356185ms step_avg:71.24ms +step:5100/20000 train_loss:1.9736 train_time:363192ms step_avg:71.21ms +step:5200/20000 train_loss:2.1232 train_time:370332ms step_avg:71.22ms +step:5300/20000 train_loss:2.1547 train_time:377368ms step_avg:71.20ms +step:5400/20000 train_loss:2.1406 train_time:384527ms step_avg:71.21ms +step:5500/20000 train_loss:2.0907 train_time:391594ms step_avg:71.20ms +step:5500/20000 val_loss:2.0962 val_bpb:1.2415 train_time:391599ms step_avg:71.20ms +step:5600/20000 train_loss:2.1262 train_time:398749ms step_avg:71.21ms +step:5700/20000 train_loss:2.1155 train_time:405793ms step_avg:71.19ms +step:5800/20000 train_loss:2.0729 train_time:412916ms step_avg:71.19ms +step:5900/20000 train_loss:2.0302 train_time:419966ms step_avg:71.18ms +step:6000/20000 train_loss:2.1519 train_time:427122ms step_avg:71.19ms +step:6000/20000 val_loss:2.0766 val_bpb:1.2299 train_time:427128ms step_avg:71.19ms +step:6100/20000 train_loss:2.0468 train_time:434165ms step_avg:71.17ms +step:6200/20000 train_loss:2.0170 train_time:441352ms step_avg:71.19ms +step:6300/20000 train_loss:1.9574 train_time:448488ms step_avg:71.19ms +step:6400/20000 train_loss:2.0902 train_time:455495ms step_avg:71.17ms +step:6500/20000 train_loss:2.0005 train_time:462656ms step_avg:71.18ms +step:6500/20000 val_loss:2.0535 val_bpb:1.2162 train_time:462663ms step_avg:71.18ms +step:6600/20000 train_loss:2.0357 train_time:469693ms step_avg:71.17ms +step:6700/20000 train_loss:2.0660 train_time:476844ms step_avg:71.17ms +step:6800/20000 train_loss:2.0844 train_time:483868ms step_avg:71.16ms +step:6900/20000 train_loss:2.0025 train_time:491041ms step_avg:71.17ms +step:7000/20000 train_loss:2.1280 train_time:498064ms step_avg:71.15ms +step:7000/20000 val_loss:2.0285 val_bpb:1.2014 train_time:498071ms step_avg:71.15ms +step:7100/20000 train_loss:1.9576 train_time:505230ms step_avg:71.16ms +step:7200/20000 train_loss:2.0894 train_time:512276ms step_avg:71.15ms +step:7300/20000 train_loss:1.9834 train_time:519403ms step_avg:71.15ms +step:7400/20000 train_loss:2.0070 train_time:526452ms step_avg:71.14ms +step:7500/20000 train_loss:1.9935 train_time:533627ms step_avg:71.15ms +step:7500/20000 val_loss:2.0001 val_bpb:1.1846 train_time:533631ms step_avg:71.15ms +step:7600/20000 train_loss:1.8687 train_time:540658ms step_avg:71.14ms +step:7700/20000 train_loss:1.9450 train_time:547839ms step_avg:71.15ms +step:7800/20000 train_loss:2.0092 train_time:554902ms step_avg:71.14ms +step:7900/20000 train_loss:1.9825 train_time:562061ms step_avg:71.15ms +step:8000/20000 train_loss:1.9669 train_time:569106ms step_avg:71.14ms +step:8000/20000 val_loss:1.9684 val_bpb:1.1658 train_time:569111ms step_avg:71.14ms +step:8100/20000 train_loss:1.9936 train_time:576274ms step_avg:71.14ms +step:8200/20000 train_loss:2.0323 train_time:583312ms step_avg:71.14ms +step:8300/20000 train_loss:1.9498 train_time:590489ms step_avg:71.14ms +step:8400/20000 train_loss:1.9599 train_time:597868ms step_avg:71.17ms +step:8428/20000 val_loss:1.9449 val_bpb:1.1519 train_time:599841ms step_avg:71.17ms +stopping_early: wallclock_cap train_time:599841ms step:8428/20000 +peak memory: 14059 MiB +ema: applying EMA weights +Serialized model: 106313663 bytes Code: 49961 bytes +final_int8_zlib_roundtrip compressed_model_bytes:15779229 code_bytes:49961 total_artifact_bytes:15829190 +Serialized zstd-22: 15779229 bytes Total: 15829190 bytes +ttt: p1=0ep p2=25ep freeze=0 +ttt_p2: 25ep SGD lr=0.012 freeze=0 params=25974872 +ttt_p2: done in 388031ms +ttt: total 388035ms +final_eval_mode:sliding_window stride:64 batch_seqs:32 +final_roundtrip val_loss:1.9002 val_bpb:1.1254 eval_time:186653ms +final_roundtrip_exact val_loss:1.90018795 val_bpb:1.12540132