diff --git a/records/track_non_record_16mb/2026-04-07_Codebooks/README.md b/records/track_non_record_16mb/2026-04-07_Codebooks/README.md
new file mode 100644
index 0000000000..b537c69736
--- /dev/null
+++ b/records/track_non_record_16mb/2026-04-07_Codebooks/README.md
@@ -0,0 +1,68 @@
# Non-Record: Codebooks! - val_bpb 1.2067 (3-seed mean)

n.b. This is not a competitive record submission, but it was done under record conditions and will hopefully make its way into a leaderboard submission at some point!

**val bpb: 1.20667** (3-seed mean, std=0.00368)

| Seed | Steps | Pre-quant BPB | Post-quant BPB | **Sliding BPB** | Artifact bytes |
|-|-|-|-|-|-|
| 42 | 4822 | 1.10450 | 1.22800 | **1.21108** | 15863950 |
| 1024 | 4826 | 1.10397 | 1.21940 | **1.20207** | 15881168 |
| 1337 | 4866 | 1.10417 | 1.22427 | **1.20694** | 15859963 |
| **Mean** | 4838 | 1.10421 | 1.22389 | **1.20667** | 15868360 |

I've been back for a day or two and have been messing about with VQ/codebook approaches; the competition seems to be dying down a bit, so I thought I'd do a little write-up for the benefit of anyone else interested in this line of work, since my approach is only half-working at the moment. Putting together a record submission at this point requires a bunch of systems/TTT work that I don't want in this PR anyway, even if I could get the quant gap down to something competitive.

In general, the motivation for trying codebooks is that vector quantization may be able to get us under the int6 limit for MLP/attn weights we've all been running into (although practically, competitive submissions are already around ~3.5 bpw after compression) and down to 1-3 bits, while still training and optimizing in healthier datatypes. Codebooks are certainly the most powerful mode of compression if you know which codes to use, and knowing that is downstream of knowing more about our model's structure than Brotli/LZMA does.
Unfortunately I'm not there yet - while I can get to around ~1.20 bpb under competition conditions with this setup, and I can squeeze another 2 layers in, I can't close the quant gap. I do want to work a little harder on this over the next few weeks, but first I'm going to do some systems work elsewhere, because I wanna learn CuTEDSL.

Below is a rundown of my discoveries and what I'm sticking with:

## E8P Lattice Fixed Codebook

I took this from the [QuIP# paper](https://github.com/Cornell-RelaxML/quip-sharp/tree/main), one of several papers, together with [AQLM](https://arxiv.org/abs/2401.06118) and [VPTQ](https://arxiv.org/abs/2409.17066), that I've taken some inspiration from. In our environment there's huge upside to a fixed codebook, since we then don't need to store the codebook itself and can save 1-2MB. In particular, this codebook is built on E8, the densest lattice packing in 8 dimensions, so it should cover iid Gaussian 8D vectors well. We can chop our weights into 8D blocks and store a 16-bit index per block, for a total of 2.0 bpw, plus an 8-bit scale per block for 3.0 bpw overall. Pushing the scale below 8 bits tends to damage things significantly.

Confidence: 8/10 - we're likely not going to be able to use learned codebooks anyway, due to the size limits.

## Hadamard Transform

This was the other part of QuIP#: essentially a random sign-flip plus rotation applied to the blocked weights, which is meant to make them more isotropic and iid Gaussian. Weirdly, I didn't find this worked as well as the paper suggests, and I think that's because the model weights are already pretty isotropic. However, it may confer some small benefit on the order of 0.002 bpb, so it stays.
Confidence: 5/10

## Hessian-aware Assignment + Scales

This was definitely the best thing I did, since it took my codebooks from not working at all to mostly working. I used the GPTQ machinery already in the baseline and repurposed `collect_hessians` to produce the metric by which codebook indices and scales are selected. This was dramatically better for `val_bpb` than Euclidean distance, which is understandable - as many have noted, minimizing raw MSE on the weights does not necessarily preserve downstream performance, and the Hessian-aware metric lets us pick the assignment that is least damaging to each weight's role in the loss, similar to GPTQ. Confidence: 7/10

## Lightweight Codebook Penalties

Unfortunately, while I would really like to do QAT with this setup and force the model to get used to passing information through the codebook, it's painfully slow - the Hadamard part is relatively fast, but materializing the codebook and doing the assignment above is very time-consuming. There's also the usual problem with VQ: being discrete, it has no obvious backwards pass and needs STE or other hacks. Since we're in such a compute-constrained regime on the record board, I have to settle for proxies to QAT - and indeed QAT hasn't worked great in the other record entries so far. I might soon do a non-record submission with super-long step times where I can afford codebook quantization in the forward pass.

For now, I simply run an approximate version of the codebook quantization every 16 steps, and add an auxiliary L2 loss that pulls weights toward their codebook counterparts, which I turn on at the end of training. This at least gives the model and optimizer a heads-up about quantization, so they can start prioritizing weights that survive it. I tried some cooler ideas, but they worked no better; again, I think full QAT would be ideal. I also tried KL-distillation with the quantized model as the student, but that ate time we don't have.
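The every-16-steps quantization proxy plus L2 pull loss can be sketched roughly as follows. This is a minimal stand-in, not the actual implementation: it uses plain Euclidean assignment against an arbitrary unit-norm codebook (the real setup uses the Hessian-aware metric and the E8-derived codebook), and `nearest_codes` / `codebook_pull_loss` are hypothetical names.

```python
import torch

def nearest_codes(w: torch.Tensor, codebook: torch.Tensor, block_dim: int = 8):
    """Blockwise nearest-codeword assignment (Euclidean stand-in for the
    Hessian-aware metric described above)."""
    blocks = w.detach().reshape(-1, block_dim)                 # (N, block_dim)
    scales = blocks.norm(dim=1, keepdim=True).clamp_min(1e-8)  # stand-in for the per-block 8-bit scale
    unit = blocks / scales
    dists = torch.cdist(unit, codebook)                        # (N, K) distance to each codeword
    return dists.argmin(dim=1), scales                         # the indices we'd store, plus scales

def codebook_pull_loss(w: torch.Tensor, codebook: torch.Tensor,
                       idx: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """Auxiliary L2 loss pulling live weights toward their (frozen) quantized targets."""
    target = (codebook[idx] * scales).reshape(w.shape)  # dequantize: codeword * scale
    return (w - target).pow(2).mean()
```

In training, the idea is to refresh `idx`/`scales` only every 16 steps and add a weighted `codebook_pull_loss(...)` term to the loss, switching it on near the end of training so the optimizer starts herding weights toward representable points.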

## Outlier Paths

One gimme is always to provide a route around quantization for particularly difficult tensors; I had about 700KB left, so I simply let the tensors with the worst Hessian-derived reconstruction error fall back to int8. This earned back a tiny bit of bpb, but nothing major. Confidence: 8/10, lameness 10/10.

## Reject Bin

Some things I tried that didn't work:

Multiple codebooks: these sound like an absolutely awesome idea (I love the [AQLM paper](https://arxiv.org/abs/2401.06118)), but I found them hard to optimize, particularly codebooks intended to store residual corrections. AQLM itself involves some really gnarly machinery, since you're solving a joint optimization over multiple discrete objects. They also take up a lot of space. I think some kind of hierarchical/residual/additive codebook scheme would be cool, but I need to figure out why this one codebook isn't working great before adding another.

Shared codebooks: storing one codebook for MLP and one for attention sounds great, but it obviously requires storing two codebooks, which wastes space. The per-module split worked well, but a single codebook shared between all tensors worked well enough that the extra storage wasn't justified.

Learning codebooks in general: since these are discrete clusterings, we can't really use gradient descent, so people commonly use k-means; that takes a lot of time (it isn't well accelerated on GPUs) and doesn't let us optimize the codebook for our actual downstream goal of compression. Ultimately we have a choice over a) what the entries of the codebook are and b) which index to pick for each block, and we can only optimize one at a time.

Entropy-weighted assignment: I tried various gambits to encourage the model to reuse codes where it could; this decreased compressed storage size as expected, but damaged performance by more than it saved.

Mega-bitcrushed scales: as expected, going below 2 bpw with this setup produced completely incoherent models - which makes sense; these are not bitnets.

Voronoi auxiliary loss: I had the idea that a loss punishing weights for sitting on the boundary between codebook cells would act as a regularizer; it kind of worked, but not as well as the simpler L2 auxiliary loss described above.

Snapping: I'm a huge proponent of trying dumb stuff first, so I tried just snapping the weights to their quantized locations every few steps during training. This actually worked surprisingly well - better than many of the gigabrain methods I tried - but not as well as the L2 loss described above.

## Conclusion

I'm a little miffed I wasn't able to close the quantization gap further; the raw 13-layer model before quantization would certainly top the leaderboard, without any TTT, so the challenge is just fitting the codebook structure to the model. As I said, I suspect better QAT strategies may be the key to unlocking codebooks at a competitive level.

I am proud of how fine-grained this setup's control is over where in the model to spend bytes, and therefore entropy; with a codebook you know the bpw ahead of time and can allocate more or less to the embedding, more or less to norms vs. directions, and so on. With a better understanding of the latent space (or more control over it, via regularization) it should be possible to design a codebook for the way this particular parameter-golf family of models learns.

As I said, I think the competition is dying down, but there is still plenty of meat on the bones of the record leaderboard; from now on, though, most of the gain will come from systems work that gets more data through the limited compute, rather than from ML work on compression. With that in mind, I'm excited to get to work on some kernels.
\ No newline at end of file diff --git a/records/track_non_record_16mb/2026-04-07_Codebooks/submission.json b/records/track_non_record_16mb/2026-04-07_Codebooks/submission.json new file mode 100644 index 0000000000..cd8760e427 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-07_Codebooks/submission.json @@ -0,0 +1,47 @@ +{ + "author": "Spruce Campbell", + "github_id": "mtybadger", + "name": "Codebooks", + "blurb": "Codebooks (VQ/codebook approach under record conditions; see README for details).", + "date": "2026-04-07", + "track": "track_non_record_16mb", + "val_loss": 2.8116260, + "val_bpb": 1.20667, + "val_loss_std": 0.00857523, + "val_bpb_std": 0.00368035, + "seeds": [42, 1024, 1337], + "seed_results": { + "42": { + "val_loss": 2.82183456, + "val_bpb": 1.21108460, + "artifact_bytes": 15863950, + "steps": 4822, + "step_avg_ms": null + }, + "1337": { + "val_loss": 2.81219097, + "val_bpb": 1.20694573, + "artifact_bytes": 15859963, + "steps": 4866, + "step_avg_ms": null + }, + "1024": { + "val_loss": 2.80085242, + "val_bpb": 1.20207941, + "artifact_bytes": 15881168, + "steps": 4826, + "step_avg_ms": null + } + }, + "comparison_baseline_pr": 1218, + "artifact_bytes_mean": 15868360, + "artifact_bytes_max": 15881168, + "bytes_total": 15881168, + "train_steps_mean": 4838.00, + "step_avg_ms_mean": null, + "hardware": "8xH100 80GB SXM", + "pytorch_version": "2.9.1+cu128", + "cuda_version": "12.8", + "flash_attn_version": "2.8.3 (FA3 Hopper kernels)" + } + \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-04-07_Codebooks/train_gpt.py b/records/track_non_record_16mb/2026-04-07_Codebooks/train_gpt.py new file mode 100644 index 0000000000..bdc9937319 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-07_Codebooks/train_gpt.py @@ -0,0 +1,2050 @@ +import copy +import glob +import io +import lzma +import math +import os +from pathlib import Path +import random +import subprocess +import sys +import time +import uuid + +import numpy as 
np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +from torch import Tensor, nn + +from flash_attn_interface import flash_attn_func as flash_attn_3_func + +# ---------------------------------------- +# Hyperparameters +# ---------------------------------------- + +class Hyperparameters(): + # Experiment settings + data_dir = os.environ.get('DATA_DIR', './data/') + seed = int(os.environ.get('SEED', 1337)) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + + # Training length + iterations = int(os.environ.get('ITERATIONS', 20000)) + warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', 0.667)) + warmup_steps = int(os.environ.get('WARMUP_STEPS', 20)) + train_batch_tokens = int(os.environ.get('TRAIN_BATCH_TOKENS', 2048 * 48 * 8)) + train_seq_len = int(os.environ.get('TRAIN_SEQ_LEN', 2048)) + eval_seq_len = int(os.environ.get('EVAL_SEQ_LEN', 2048)) + max_wallclock_seconds = float(os.environ.get('MAX_WALLCLOCK_SECONDS', 600.0)) + train_log_every = int(os.environ.get('TRAIN_LOG_EVERY', 500)) + + # Validation/Evals + val_batch_tokens = int(os.environ.get('VAL_BATCH_TOKENS', 2048 * 32 * 8)) + val_loss_every = int(os.environ.get('VAL_LOSS_EVERY', 4000)) + sliding_window_enabled = bool(int(os.environ.get('SLIDING_WINDOW_ENABLED', '1'))) + + # Model architecture + vocab_size = int(os.environ.get('VOCAB_SIZE', 4096)) + num_layers = int(os.environ.get('NUM_LAYERS', 11)) + xsa_last_n = int(os.environ.get('XSA_LAST_N', 11)) + num_kv_heads = int(os.environ.get('NUM_KV_HEADS', 4)) + model_dim = int(os.environ.get('MODEL_DIM', 512)) + embedding_dim = int(os.environ.get('EMBEDDING_DIM', 512)) + num_heads = int(os.environ.get('NUM_HEADS', 8)) + mlp_mult = float(os.environ.get('MLP_MULT', 4.0)) + skip_gates_enabled = bool(int(os.environ.get('SKIP_GATES_ENABLED', '1'))) + tie_embeddings = bool(int(os.environ.get('TIE_EMBEDDINGS', '1'))) + logit_softcap = 
float(os.environ.get('LOGIT_SOFTCAP', 30.0)) + rope_base = float(os.environ.get('ROPE_BASE', 10000.0)) + rope_dims = int(os.environ.get('ROPE_DIMS', 16)) + rope_train_seq_len = int(os.environ.get('ROPE_TRAIN_SEQ_LEN', 2048)) + ln_scale = bool(int(os.environ.get('LN_SCALE', '1'))) + ve_enabled = bool(int(os.environ.get('VE_ENABLED', '1'))) + ve_dim = int(os.environ.get('VE_DIM', 128)) + ve_layers = os.environ.get('VE_LAYERS', '9,10') + qk_gain_init = float(os.environ.get('QK_GAIN_INIT', 4.0)) + + # Optimizer + min_lr = float(os.environ.get('MIN_LR', 0.0)) + embed_lr = float(os.environ.get('EMBED_LR', 0.6)) + head_lr = float(os.environ.get('HEAD_LR', 0.008)) + tied_embed_lr = float(os.environ.get('TIED_EMBED_LR', 0.03)) + tied_embed_init_std = float(os.environ.get('TIED_EMBED_INIT_STD', 0.005)) + matrix_lr = float(os.environ.get('MATRIX_LR', 0.02)) + scalar_lr = float(os.environ.get('SCALAR_LR', 0.02)) + muon_momentum = float(os.environ.get('MUON_MOMENTUM', 0.99)) + muon_backend_steps = int(os.environ.get('MUON_BACKEND_STEPS', 5)) + muon_momentum_warmup_start = float(os.environ.get('MUON_MOMENTUM_WARMUP_START', 0.92)) + muon_momentum_warmup_steps = int(os.environ.get('MUON_MOMENTUM_WARMUP_STEPS', 1500)) + beta1 = float(os.environ.get('BETA1', 0.9)) + beta2 = float(os.environ.get('BETA2', 0.95)) + adam_eps = float(os.environ.get('ADAM_EPS', 1e-8)) + grad_clip_norm = float(os.environ.get('GRAD_CLIP_NORM', 0.3)) + eval_stride = int(os.environ.get('EVAL_STRIDE', 64)) + muon_beta2 = float(os.environ.get('MUON_BETA2', 0.95)) + adam_wd = float(os.environ.get('ADAM_WD', 0.02)) + muon_wd = float(os.environ.get('MUON_WD', 0.085)) + embed_wd = float(os.environ.get('EMBED_WD', 0.085)) + ema_decay = float(os.environ.get('EMA_DECAY', 0.997)) + + # Compression + compressor = os.environ.get('COMPRESSOR', 'brotli') #(lzma or brotli) + codebook_block_dim = int(os.environ.get("CODEBOOK_BLOCK_DIM", 8)) + codebook_use_hadamard = bool(int(os.environ.get("CODEBOOK_USE_HADAMARD", "1"))) + 
    codebook_calibration_batches = int(os.environ.get("CODEBOOK_CALIBRATION_BATCHES", 64))
    codebook_reserve_seconds = float(os.environ.get("CODEBOOK_RESERVE_SECONDS", 10.0))
    codebook_hessian_damp = float(os.environ.get("CODEBOOK_HESSIAN_DAMP", 0.01))
    codebook_lattice_scale = float(os.environ.get("CODEBOOK_LATTICE_SCALE", 1.03))

    # Distributed setup
    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    is_main_process = rank == 0
    grad_accum_steps = 8 // world_size

    # Data paths
    datasets_dir = os.path.join(data_dir, 'datasets', f'fineweb10B_sp{vocab_size}')
    train_files = os.path.join(datasets_dir, 'fineweb_train_*.bin')
    val_files = os.path.join(datasets_dir, 'fineweb_val_*.bin')
    tokenizer_path = os.path.join(data_dir, 'tokenizers', f'fineweb_{vocab_size}_bpe.model')

    # Experiment files
    logfile = f"logs/{run_id}.txt"
    model_path = "final_model.pt"
    quantized_model_path = "final_model.codebook.ptz"

# ----------------------------------------
# Global Logging Function
# ----------------------------------------

_logger_hparams = None


def set_logging_hparams(h: Hyperparameters) -> None:
    global _logger_hparams
    _logger_hparams = h


def log(msg, console: bool = True) -> None:
    if _logger_hparams is None:
        print(msg)
        return  # hparams not set yet; avoid the None attribute access below
    if _logger_hparams.is_main_process:
        if console:
            print(msg)
        if _logger_hparams.logfile is not None:
            with open(_logger_hparams.logfile, "a", encoding="utf-8") as f:
                print(msg, file=f)

# ----------------------------------------
# Data Loading
# ----------------------------------------

class ValidationData:
    def __init__(self, h: Hyperparameters, device: torch.device):
        if not h.tokenizer_path.endswith(".model"):
            raise ValueError(f"Script only setup for SentencePiece .model file: {h.tokenizer_path}")
        self.sp = 
spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError( + f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}" + ) + + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + self.base_bytes_lut, self.has_leading_space_lut, self.is_boundary_token_lut = ( + build_sentencepiece_luts(self.sp, h.vocab_size, device)) + + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + # The BPB calculation assumes "▁" is its own token so that leading-space bytes + # are counted correctly. See https://github.com/openai/parameter-golf/issues/897 + assert sp.piece_to_id("\u2581") != sp.unk_id(), \ + "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting" + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files 
found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" int: + key = str(file) + cached = _SHARD_NTOKENS_CACHE.get(key) + if cached is not None: + return cached + header = np.fromfile(file, dtype=" np.memmap: + key = str(file) + mm = _MMAP_CACHE.get(key) + if mm is not None: + return mm + n = _read_num_tokens(file) + mm = np.memmap(file, mode="r", dtype=" int: + if n <= 1: + return 1 + while True: + s = int(self._rng.integers(1, n)) + if math.gcd(s, n) == 1: + return s + + def _reset_cursor(self, si: int, seq_len: int) -> None: + nt = int(self._num_tokens[si]) + max_phase = min(seq_len - 1, max(0, nt - seq_len - 1)) + phase = int(self._rng.integers(max_phase + 1)) if max_phase > 0 else 0 + bc = (nt - 1 - phase) // seq_len + self._cursor_phase[si] = phase + self._cursor_block_count[si] = bc + self._cursor_next[si] = 0 + self._cursor_start[si] = int(self._rng.integers(bc)) if bc > 1 else 0 + self._cursor_stride[si] = self._pick_coprime_stride(bc) + self._cursor_init[si] = True + + def _ensure_cursor(self, si: int, seq_len: int) -> None: + if not self._cursor_init[si] or self._cursor_next[si] >= self._cursor_block_count[si]: + self._reset_cursor(si, seq_len) + + def _take_from_shard(self, si: int, seq_len: int, count: int, out: list[tuple[int, int]]) -> None: + rem = count + while rem > 0: + self._ensure_cursor(si, seq_len) + bc = int(self._cursor_block_count[si]) + ni = int(self._cursor_next[si]) + take = min(rem, bc - ni) + phase = int(self._cursor_phase[si]) + start = int(self._cursor_start[si]) + stride = int(self._cursor_stride[si]) + for j in range(take): + bi = 
(start + (ni + j) * stride) % bc + out.append((si, phase + bi * seq_len)) + self._cursor_next[si] = ni + take + rem -= take + + def _init_pipeline(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> None: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + num_seqs = local_tokens // seq_len + global_num_seqs = num_seqs * self.world_size + self._cfg = (local_tokens, seq_len, num_seqs, global_num_seqs) + bbc = (self._num_tokens - 1) // seq_len + eligible = bbc > 0 + self._eligible_shards = np.nonzero(eligible)[0].astype(np.int64) + self._base_block_counts = bbc[self._eligible_shards].astype(np.int64) + + def _sample_global_windows(self) -> list[tuple[int, int]]: + assert self._cfg is not None and self._eligible_shards is not None + _, seq_len, _, gns = self._cfg + ec = int(self._eligible_shards.size) + progress = min(self._batches_built / 1800.0, 1.0) + remaining = np.empty(ec, dtype=np.float64) + for i, si in enumerate(self._eligible_shards.tolist()): + if self._cursor_init[si]: + r = int(self._cursor_block_count[si]) - int(self._cursor_next[si]) + remaining[i] = float(max(r, 1)) + else: + remaining[i] = float(self._base_block_counts[i]) + alpha = 0.90 - 0.40 * progress + weights = np.power(remaining, alpha) + ws = float(weights.sum()) + if not np.isfinite(ws) or ws <= 0.0: + weights = np.ones(ec, dtype=np.float64) + ws = float(weights.sum()) + probs = weights / ws + low = min(max(8, self.world_size), ec, gns) + high = min(max(32, self.world_size * 8), ec, gns) + mix = max(1, min(int(round(low + progress * (high - low))), ec, gns)) + cp = self._rng.choice(ec, size=mix, replace=False, p=probs) + cs = self._eligible_shards[cp] + cpr = probs[cp].copy() + cpr /= cpr.sum() + counts = np.ones(mix, dtype=np.int64) + extra = gns - mix + if extra > 0: + counts += self._rng.multinomial(extra, cpr).astype(np.int64) + perm = self._rng.permutation(mix) + cs, counts = cs[perm], counts[perm] + buckets: list[list[tuple[int, int]]] = [] + for si, 
cnt in zip(cs.tolist(), counts.tolist()): + b: list[tuple[int, int]] = [] + self._take_from_shard(int(si), seq_len, int(cnt), b) + if b: + if len(b) > 1: + bp = self._rng.permutation(len(b)) + b = [b[int(k)] for k in bp.tolist()] + buckets.append(b) + windows: list[tuple[int, int]] = [] + active = [i for i, bk in enumerate(buckets) if bk] + while active: + order = self._rng.permutation(len(active)) + new_active: list[int] = [] + for oi in order.tolist(): + bi = active[oi] + if buckets[bi]: + windows.append(buckets[bi].pop()) + if buckets[bi]: + new_active.append(bi) + active = new_active + return windows + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + if self._cfg is None: + self._init_pipeline(global_tokens, seq_len, grad_accum_steps) + _, _, num_seqs, _ = self._cfg + gw = self._sample_global_windows() + local_w = gw[self.rank::self.world_size] + x = torch.empty((num_seqs, seq_len), dtype=torch.int64) + y = torch.empty((num_seqs, seq_len), dtype=torch.int64) + for slot, (si, pos) in enumerate(local_w): + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor(np.array(mm[pos:pos + seq_len + 1], dtype=np.int64)) + x[slot] = window[:-1] + y[slot] = window[1:] + self._batches_built += 1 + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ---------------------------------------- +# Model Architecture +# ---------------------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, 
rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, + rope_base: float, qk_gain_init: float, train_seq_len: 
int): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len) + self.use_xsa = False + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return 
self.proj(y) + + +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square()) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, + rope_base: float, qk_gain_init: float, train_seq_len: int, + layer_idx: int = 0, ln_scale: bool = False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * 
self.ln_scale_factor, v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + return x_out + + +class GPT(nn.Module): + def __init__(self, h: Hyperparameters): + super().__init__() + self._ve_target_dim = h.num_kv_heads * (h.model_dim // h.num_heads) + if h.logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.tok_emb = nn.Embedding(h.vocab_size, h.embedding_dim) + if h.embedding_dim != h.model_dim: + self.embed_proj = CastedLinear(h.embedding_dim, h.model_dim, bias=False) + self.head_proj = CastedLinear(h.model_dim, h.embedding_dim, bias=False) + else: + self.embed_proj = None + self.head_proj = None + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32)) + self.skip_gates = nn.Parameter(torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32)) if h.skip_gates_enabled else None + self.blocks = nn.ModuleList([ + Block(h.model_dim, h.num_heads, h.num_kv_heads, h.mlp_mult, h.rope_base, + h.qk_gain_init, h.train_seq_len, layer_idx=i, ln_scale=h.ln_scale) + for i in range(h.num_layers) + ]) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary(head_dim, base=h.rope_base, train_seq_len=h.train_seq_len, rope_dims=h.rope_dims) + self.ve_layer_indices = [int(x) for x in h.ve_layers.split(",") if x.strip()] if h.ve_enabled else [] + kv_dim = self._ve_target_dim + if 
self.ve_layer_indices: + self.ve_shared = ValueEmbedding(h.vocab_size, h.ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if h.tie_embeddings else CastedLinear(h.embedding_dim, h.vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.embed_proj is not None: + x = self.embed_proj(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + 
skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + scaled_skip = self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + if self.head_proj is not None: + x = self.head_proj(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + return F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1), reduction="mean") + +# ---------------------------------------- +# Optimization +# ---------------------------------------- + +@torch.compile +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 
1
+ rank = dist.get_rank() if distributed else 0
+ for group in self.param_groups:
+ params = group["params"]
+ if not params:
+ continue
+ lr = group["lr"]
+ momentum = group["momentum"]
+ backend_steps = group["backend_steps"]
+ nesterov = group["nesterov"]
+ total_params = sum(int(p.numel()) for p in params)
+ updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
+ curr = 0
+ for i, p in enumerate(params):
+ if i % world_size == rank and p.grad is not None:
+ g = p.grad
+ state = self.state[p]
+ if "momentum_buffer" not in state:
+ state["momentum_buffer"] = torch.zeros_like(g)
+ buf = state["momentum_buffer"]
+ buf.mul_(momentum).add_(g)
+ if nesterov:
+ g = g.add(buf, alpha=momentum)
+ g = zeropower_via_newtonschulz5(g, steps=backend_steps)
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5
+ updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+ curr += p.numel()
+ if distributed:
+ dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+ wd = group.get("weight_decay", 0.0)
+ curr = 0
+ for p in params:
+ if wd > 0.0:
+ p.data.mul_(1.0 - lr * wd)
+ g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+ p.add_(g, alpha=-lr)
+ curr += p.numel()
+ return loss
+
+
+class Optimizers:
+ def __init__(self, h: Hyperparameters, base_model: GPT):
+ block_named_params = list(base_model.blocks.named_parameters())
+ matrix_params = [
+ p
+ for name, p in block_named_params
+ if p.ndim == 2 and not any(pattern in name for pattern in
+ CONTROL_TENSOR_NAME_PATTERNS)
+ ]
+ scalar_params = [
+ p
+ for name, p in block_named_params
+ if p.ndim < 2 or any(pattern in name for pattern in
+ CONTROL_TENSOR_NAME_PATTERNS)
+ ]
+ if base_model.skip_weights.numel() > 0:
+ scalar_params.append(base_model.skip_weights)
+ if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0:
+ scalar_params.append(base_model.skip_gates)
+
+ token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr
+ tok_params = [{"params": 
[base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + + self.optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.embed_wd, + fused=True, + ) + self.optimizer_muon = Muon( + matrix_params, + lr=h.matrix_lr, + momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, + weight_decay=h.muon_wd, + ) + for group in self.optimizer_muon.param_groups: + group["base_lr"] = h.matrix_lr + self.optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.adam_wd, + fused=True, + ) + self.optimizers: list[torch.optim.Optimizer] = [self.optimizer_tok, self.optimizer_muon, self.optimizer_scalar] + if base_model.lm_head is not None: + self.optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": h.head_lr, "base_lr": h.head_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + fused=True, + ) + self.optimizers.insert(1, self.optimizer_head) + else: + self.optimizer_head = None + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self) -> None: + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def step(self): + for opt in self.optimizers: + opt.step() + self.zero_grad_all() + +# ---------------------------------------- +# Quantization +# ---------------------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + 
"attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +CODEBOOK_TARGET_PATTERNS = ( + "attn.c_q.weight", + "attn.c_k.weight", + "attn.c_v.weight", + "attn.proj.weight", + "mlp.fc.weight", + "mlp.proj.weight", +) + + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def dequantize_int8_rows(q: Tensor, s: Tensor, *, device: torch.device, dtype: torch.dtype) -> Tensor: + qf = q.to(device=device, dtype=torch.float32) + sf = s.to(device=device, dtype=torch.float32) + if sf.ndim > 0: + return (qf * sf.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype) + return (qf * float(sf.item())).to(dtype=dtype) + + +def restore_fp32_params(model: nn.Module) -> None: + """After .bfloat16(), restore CastedLinear weights and control params to FP32.""" + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern 
in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +def _is_power_of_two(n: int) -> bool: + return n > 0 and (n & (n - 1)) == 0 + + +def _validate_codebook_hparams(h: Hyperparameters) -> None: + if not _is_power_of_two(h.codebook_block_dim): + raise ValueError(f"CODEBOOK_BLOCK_DIM must be a power of 2, got {h.codebook_block_dim}") + if h.codebook_block_dim != 8: + raise ValueError(f"This E8P baseline expects CODEBOOK_BLOCK_DIM=8, got {h.codebook_block_dim}") + if h.codebook_calibration_batches < 0: + raise ValueError( + f"CODEBOOK_CALIBRATION_BATCHES must be non-negative, got {h.codebook_calibration_batches}" + ) + if h.codebook_reserve_seconds < 0: + raise ValueError( + f"CODEBOOK_RESERVE_SECONDS must be non-negative, got {h.codebook_reserve_seconds}" + ) + if h.codebook_hessian_damp < 0: + raise ValueError( + f"CODEBOOK_HESSIAN_DAMP must be non-negative, got {h.codebook_hessian_damp}" + ) + if h.codebook_lattice_scale <= 0: + raise ValueError(f"CODEBOOK_LATTICE_SCALE must be positive, got {h.codebook_lattice_scale}") + + +def _blockify_weight(t: Tensor, block_dim: int) -> tuple[Tensor, tuple[int, int]]: + if t.ndim != 2: + raise ValueError(f"Codebook quantization expects a 2D tensor, got shape {tuple(t.shape)}") + if t.shape[1] % block_dim != 0: + raise ValueError( + f"Tensor shape {tuple(t.shape)} is not compatible with CODEBOOK_BLOCK_DIM={block_dim}" + ) + return t.contiguous().view(-1, block_dim), (int(t.shape[0]), int(t.shape[1])) + + +def _unblockify_weight(blocks: Tensor, original_shape: tuple[int, int]) -> Tensor: + return blocks.contiguous().view(original_shape) + + +def _should_codebook_quantize(name: str, t: Tensor, h: Hyperparameters) -> bool: + return ( + t.is_floating_point() + and t.ndim == 2 + and t.shape[1] % h.codebook_block_dim == 0 + and name.startswith("blocks.") + and any(name.endswith(pattern) for pattern in CODEBOOK_TARGET_PATTERNS) + ) + + +def normalize_blocks(blocks: Tensor, eps: 
float = 1e-8) -> tuple[Tensor, Tensor]: + scales = blocks.norm(dim=-1, keepdim=True).clamp_min(eps) + return blocks / scales, scales + + +def _hadamard_sign_vector(name: str, block_dim: int, *, device: torch.device, dtype: torch.dtype) -> Tensor: + rng = random.Random(f"hadamard::{name}::{block_dim}") + return torch.tensor( + [1.0 if rng.getrandbits(1) else -1.0 for _ in range(block_dim)], + device=device, + dtype=dtype, + ) + + +def hadamard_transform(x: Tensor, scale: float = 1.0) -> Tensor: + original_shape = x.shape + dim = x.shape[-1] + padded_dim = 1 << (max(dim, 1) - 1).bit_length() + out = x.reshape(-1, dim) + if padded_dim != dim: + out = F.pad(out, (0, padded_dim - dim)) + h = 1 + while h < padded_dim: + out = out.view(-1, padded_dim // (2 * h), 2, h) + a = out[:, :, 0, :] + b = out[:, :, 1, :] + out = torch.stack((a + b, a - b), dim=2).reshape(-1, padded_dim) + h *= 2 + out = out[:, :dim] + return (out * scale).reshape(*original_shape) + + +def hadamard_rotate_blocks(blocks: Tensor, sign_vec: Tensor, *, enabled: bool = True) -> Tensor: + if not enabled: + return blocks + return hadamard_transform(blocks * sign_vec, scale=blocks.shape[-1] ** -0.5) + + +def hadamard_unrotate_blocks(blocks: Tensor, sign_vec: Tensor, *, enabled: bool = True) -> Tensor: + if not enabled: + return blocks + return hadamard_transform(blocks, scale=blocks.shape[-1] ** -0.5) * sign_vec + + +def get_norm12() -> Tensor: + return torch.tensor([ + [3, 1, 1, 1, 3, 3, 3, 3], + [1, 3, 1, 1, 3, 3, 3, 3], + [1, 1, 3, 1, 3, 3, 3, 3], + [1, 1, 1, 3, 3, 3, 3, 3], + [3, 3, 3, 1, 3, 3, 1, 1], + [3, 3, 3, 1, 3, 1, 3, 1], + [3, 3, 3, 1, 1, 3, 3, 1], + [3, 3, 3, 1, 3, 1, 1, 3], + [3, 3, 3, 1, 1, 3, 1, 3], + [3, 3, 3, 1, 1, 1, 3, 3], + [3, 3, 1, 3, 3, 3, 1, 1], + [3, 3, 1, 3, 3, 1, 3, 1], + [3, 3, 1, 3, 1, 3, 3, 1], + [3, 3, 1, 3, 3, 1, 1, 3], + [3, 3, 1, 3, 1, 3, 1, 3], + [3, 3, 1, 3, 1, 1, 3, 3], + [3, 1, 3, 3, 3, 3, 1, 1], + [3, 1, 3, 3, 3, 1, 3, 1], + [3, 1, 3, 3, 1, 3, 3, 1], + [3, 1, 3, 3, 3, 
1, 1, 3], + [3, 1, 3, 3, 1, 3, 1, 3], + [1, 3, 3, 3, 1, 1, 3, 3], + [1, 3, 3, 3, 3, 3, 1, 1], + [1, 3, 3, 3, 3, 1, 3, 1], + [1, 3, 3, 3, 1, 3, 3, 1], + [1, 3, 3, 3, 3, 1, 1, 3], + [1, 3, 3, 3, 1, 3, 1, 3], + [1, 1, 3, 3, 1, 3, 3, 3], + [3, 3, 1, 1, 3, 3, 3, 1], + ], dtype=torch.float32) / 2 + + +def get_packed_abs_grid() -> Tensor: + intr = torch.arange(-4, 4) + d8 = torch.cartesian_prod(*[intr] * 8).float() + 0.5 + d8_even = d8.sum(dim=-1) % 2 == 0 + d8_small = d8.norm(dim=-1).square() <= 10 + d8_abs = torch.unique(d8[torch.where(d8_even & d8_small)[0]].abs(), dim=0) + codebook_abs = torch.cat([d8_abs, get_norm12()], dim=0) + codebook_abs = codebook_abs[:, [0, 2, 4, 6, 1, 3, 5, 7]] + codebook_abs[:, 7] *= (1 - 2 * (codebook_abs.sum(1) % 2)) + codebook_abs = (codebook_abs * 2 + 8).to(torch.int32) + acc = codebook_abs[:, 0] + for i in range(7): + acc = acc | (codebook_abs[:, i + 1] << ((i + 1) * 4)) + return acc + + +def get_abs_grid() -> Tensor: + intr = torch.arange(-4, 4) + d8 = torch.cartesian_prod(*[intr] * 8).float() + 0.5 + d8_even = d8.sum(dim=-1) % 2 == 0 + d8_small = d8.norm(dim=-1).square() <= 10 + d8_abs = torch.unique(d8[torch.where(d8_even & d8_small)[0]].abs(), dim=0) + return torch.cat([d8_abs, get_norm12()], dim=0) + + +def get_full_grid(packed_abs_grid: Tensor) -> tuple[Tensor, Tensor]: + codebook = torch.zeros(1 << 16, 8) + parity_idx: list[int] = [] + shuffle_map = [0, 4, 1, 5, 2, 6, 3, 7] + for code in range(1 << 16): + signs = code & 255 + abs_idx = code >> 8 + parity = 0 + for bit in range(8): + parity ^= (signs >> bit) & 1 + signs ^= parity + abs_code = packed_abs_grid[abs_idx].item() + for i, shuffled in enumerate(shuffle_map): + codebook[code, i] = (((abs_code >> (4 * shuffled)) & 15) - 8) * 0.5 + if (signs >> shuffled) & 1: + codebook[code, i] *= -1 + if parity: + codebook[code] -= 0.25 + parity_idx.append(code) + else: + codebook[code] += 0.25 + return codebook, torch.tensor(parity_idx, dtype=torch.long) + + +_E8P_PACKED_ABS = 
get_packed_abs_grid() +_E8P_GRID, _E8P_PARITY_IDX = get_full_grid(_E8P_PACKED_ABS) + + +class E8P12Codebook(nn.Module): + def __init__(self) -> None: + super().__init__() + self.codesz = 8 + self.register_buffer("grid", _E8P_GRID) + self.register_buffer("grid_norm", _E8P_GRID.norm(dim=-1).square()) + grid_part = _E8P_GRID[_E8P_PARITY_IDX] + 0.25 + grid_part = grid_part[ + torch.where( + ((grid_part[:, :7] < 0).sum(dim=-1) <= 1) + & (grid_part[:, :7].min(dim=-1).values >= -0.5) + )[0] + ] + abs_grid = get_abs_grid() + self.register_buffer("grid_part", grid_part) + self.register_buffer("grid_part_norm", grid_part.norm(dim=-1).square()) + self.register_buffer("grid_abs_odd", abs_grid.sum(dim=-1) % 2 == 1) + self.register_buffer("part_abs_map", self.round(grid_part.abs(), abs_grid, abs_grid.norm(dim=-1).square())[1]) + self.register_buffer("bit_map", 2 ** torch.arange(8)) + + def round(self, x: Tensor, grid: Tensor, grid_norm: Tensor) -> tuple[Tensor, Tensor]: + idx = (2 * x @ grid.t() - grid_norm).argmax(dim=-1) + return grid[idx], idx + + def fast_quantize_part(self, x: Tensor, parity: bool) -> tuple[Tensor, Tensor, Tensor]: + x_part = torch.abs(x) + x_odd = torch.where((x < 0).sum(dim=-1) % 2 != 0)[0] + x_part[x_odd, 7] = -x_part[x_odd, 7] + mask = 1 - 2 * (x < 0).to(torch.float32) + mask[x_odd, 7] = -mask[x_odd, 7] + rounded, rounded_idx = self.round(x_part, self.grid_part, self.grid_part_norm) + vals = rounded * mask + err = (x - vals).norm(dim=-1) + abs_idx = self.part_abs_map[rounded_idx] + sign_mask = ((rounded < 0) ^ (mask < 0))[:, [0, 2, 4, 6, 1, 3, 5, 7]] + sign_mask[:, 7] ^= self.grid_abs_odd[abs_idx] + sign_mask[:, 0] ^= parity + idx = (abs_idx << 8) + (sign_mask * self.bit_map).sum(dim=-1).int() + return vals, idx, err + + def quantize(self, x: Tensor) -> tuple[Tensor, Tensor]: + plus_vals, plus_idx, plus_err = self.fast_quantize_part(x + 0.25, True) + minus_vals, minus_idx, minus_err = self.fast_quantize_part(x - 0.25, False) + use_plus = plus_err < 
minus_err + vals = torch.where(use_plus.unsqueeze(-1), plus_vals - 0.25, minus_vals + 0.25) + idx = torch.where(use_plus, plus_idx, minus_idx) + return vals, idx + + def decode(self, idxs: Tensor) -> Tensor: + idxs_long = idxs.long().reshape(-1) + vals = self.grid.index_select(0, idxs_long) + return vals.view(*idxs.shape, self.codesz) + + +_E8P_CODEBOOK_CACHE: dict[str, E8P12Codebook] = {} + + +def _e8p_cache_key(device: torch.device | None) -> str: + if device is None: + return "cpu" + return f"{device.type}:{device.index}" + + +def _get_e8p_codebook(device: torch.device | None) -> E8P12Codebook: + key = _e8p_cache_key(device) + cached = _E8P_CODEBOOK_CACHE.get(key) + if cached is not None: + return cached + codebook = E8P12Codebook() + if device is not None: + codebook = codebook.to(device=device, dtype=torch.float32) + else: + codebook = codebook.to(dtype=torch.float32) + codebook.eval() + _E8P_CODEBOOK_CACHE[key] = codebook + return codebook + + +@torch.no_grad() +def _quantize_e8p_blocks(blocks: Tensor, lattice_scale: float) -> tuple[Tensor, Tensor]: + codebook = _get_e8p_codebook(blocks.device) + scaled = blocks.to(dtype=torch.float32) * float(lattice_scale) + vals, idxs = codebook.quantize(scaled) + return vals / float(lattice_scale), idxs.long() + + +@torch.no_grad() +def _decode_e8p_blocks( + idxs: Tensor, + lattice_scale: float, + *, + device: torch.device, + dtype: torch.dtype, +) -> Tensor: + codebook = _get_e8p_codebook(device) + vals = codebook.decode(idxs.to(device=device)) + return (vals / float(lattice_scale)).to(dtype=dtype) + + +def collect_hessians( + model: nn.Module, + train_loader: DistributedTokenLoader, + h: Hyperparameters, + device: torch.device, + target_names: set[str], + n_calibration_batches: int = 64, +) -> dict[str, Tensor]: + """Run calibration batches and collect H = X^T X for selected CastedLinear layers.""" + if n_calibration_batches <= 0 or not target_names: + return {} + hessians: dict[str, Tensor] = {} + hooks = [] + 
was_training = model.training + + def make_hook(name: str): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + return hook_fn + + for module_name, module in model.named_modules(): + if not isinstance(module, CastedLinear): + continue + weight_name = f"{module_name}.weight" + if weight_name in target_names: + hooks.append(module.register_forward_hook(make_hook(weight_name))) + + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches): + x, _ = train_loader.next_batch( + h.train_batch_tokens, + h.train_seq_len, + h.grad_accum_steps, + ) + model.forward_logits(x) + + for hook in hooks: + hook.remove() + + world = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1 + for name, hessian in hessians.items(): + if world > 1: + dist.all_reduce(hessian, op=dist.ReduceOp.SUM) + hessians[name] = hessian.cpu() / float(max(world * n_calibration_batches, 1)) + + if was_training: + model.train() + return hessians + + +def _identity_block_metrics(shape: tuple[int, int], block_dim: int, device: torch.device) -> Tensor: + num_positions = shape[1] // block_dim + eye = torch.eye(block_dim, device=device, dtype=torch.float32) + return eye.unsqueeze(0).repeat(num_positions, 1, 1) + + +def _block_metrics_from_hessian( + shape: tuple[int, int], + block_dim: int, + sign_vec: Tensor, + hessian: Tensor | None, + *, + use_hadamard: bool, + damp_factor: float, + device: torch.device, +) -> Tensor: + if hessian is None or hessian.ndim != 2 or hessian.shape != (shape[1], shape[1]): + return _identity_block_metrics(shape, block_dim, device) + num_positions = shape[1] // block_dim + hessian_work = hessian.to(device=device, dtype=torch.float32).clone() + diag = hessian_work.diag() + dead = diag <= 0 + if dead.any(): + hessian_work[dead, 
dead] = 1.0 + damp = float(damp_factor * hessian_work.diag().mean().clamp_min(1e-8).item()) + hessian_work.diagonal().add_(damp) + reshaped = hessian_work.view(num_positions, block_dim, num_positions, block_dim) + block_idx = torch.arange(num_positions, device=device) + block_metrics = reshaped[block_idx, :, block_idx, :] + if use_hadamard: + unrotate = hadamard_unrotate_blocks( + torch.eye(block_dim, device=device, dtype=torch.float32), + sign_vec, + enabled=True, + ) + block_metrics = torch.matmul( + unrotate.unsqueeze(0), + torch.matmul(block_metrics, unrotate.t().unsqueeze(0)), + ) + block_metrics = 0.5 * (block_metrics + block_metrics.transpose(-1, -2)) + block_metrics.diagonal(dim1=-2, dim2=-1).add_(1e-6) + return block_metrics + + +def _metric_optimal_scales(rotated_blocks: Tensor, codebook_blocks: Tensor, metrics: Tensor) -> Tensor: + numer = torch.einsum("npd,pde,npe->np", codebook_blocks, metrics, rotated_blocks) + denom = torch.einsum("npd,pde,npe->np", codebook_blocks, metrics, codebook_blocks).clamp_min(1e-8) + scales = (numer / denom).clamp_min(1e-8) + max_fp16 = float(torch.finfo(torch.float16).max) + return scales.clamp(max=max_fp16).unsqueeze(-1) + + +class CodebookQuantizer: + def __init__(self, h: Hyperparameters, model: nn.Module): + _validate_codebook_hparams(h) + self.h = h + self.states: dict[str, dict[str, object]] = {} + self.target_modules: list[tuple[str, CastedLinear]] = [] + self.target_names: set[str] = set() + self.last_fit_summary: dict[str, float] = {} + for module_name, module in model.named_modules(): + if not isinstance(module, CastedLinear): + continue + weight_name = f"{module_name}.weight" + if _should_codebook_quantize(weight_name, module.weight, h): + self.target_modules.append((weight_name, module)) + self.target_names.add(weight_name) + + @torch.no_grad() + def _fit_tensor(self, name: str, module: CastedLinear, hessian: Tensor | None) -> tuple[str, dict[str, float]]: + weight = module.weight.detach().float() + blocks, 
shape = _blockify_weight(weight, self.h.codebook_block_dim) + num_rows, num_cols = shape + num_positions = num_cols // self.h.codebook_block_dim + sign_vec = _hadamard_sign_vector(name, self.h.codebook_block_dim, device=blocks.device, dtype=torch.float32) + rotated_blocks = hadamard_rotate_blocks(blocks, sign_vec, enabled=self.h.codebook_use_hadamard) + rotated_grid = rotated_blocks.view(num_rows, num_positions, self.h.codebook_block_dim) + target_dirs, _ = normalize_blocks(rotated_grid) + fixed_dirs_flat, fixed_idx = _quantize_e8p_blocks( + target_dirs.reshape(-1, self.h.codebook_block_dim), + self.h.codebook_lattice_scale, + ) + fixed_dirs = fixed_dirs_flat.view(num_rows, num_positions, self.h.codebook_block_dim) + metrics = _block_metrics_from_hessian( + shape, + self.h.codebook_block_dim, + sign_vec, + hessian, + use_hadamard=self.h.codebook_use_hadamard, + damp_factor=self.h.codebook_hessian_damp, + device=blocks.device, + ) + scales = _metric_optimal_scales(rotated_grid, fixed_dirs, metrics) + scales_fp16 = scales.to(dtype=torch.float16) + recon_rotated = fixed_dirs * scales_fp16.to(dtype=torch.float32) + recon_blocks = hadamard_unrotate_blocks( + recon_rotated.view(-1, self.h.codebook_block_dim), + sign_vec, + enabled=self.h.codebook_use_hadamard, + ) + diff = blocks - recon_blocks + mse = diff.square().mean() + energy = blocks.square().mean().clamp_min(1e-12) + stats = { + "num_weights": float(weight.numel()), + "mse": float(mse.item()), + "rel_mse": float((mse / energy).item()), + "scale_min": float(scales_fp16.min().item()), + "scale_mean": float(scales_fp16.float().mean().item()), + "scale_max": float(scales_fp16.max().item()), + } + self.states[name] = { + "shape": shape, + "block_dim": self.h.codebook_block_dim, + "fixed_idx": fixed_idx.to(dtype=torch.uint16).cpu().contiguous(), + "scales": scales_fp16.reshape(-1).cpu().contiguous(), + "stats": stats, + } + return name, stats + + @torch.no_grad() + def fit(self, model: nn.Module, hessians: dict[str, 
Tensor]) -> None: + if not self.target_modules: + if self.h.is_main_process: + log("codebook:no eligible tensors found") + return + total_fp_weights = sum( + int(t.numel()) + for _, t in model.state_dict().items() + if t.is_floating_point() + ) + fit_items = [self._fit_tensor(name, module, hessians.get(name)) for name, module in self.target_modules] + target_weights = sum(int(stats["num_weights"]) for _, stats in fit_items) + weighted_rel_mse = sum(float(stats["rel_mse"]) * int(stats["num_weights"]) for _, stats in fit_items) + self.last_fit_summary = { + "target_tensors": float(len(fit_items)), + "target_weights": float(target_weights), + "coverage": target_weights / max(total_fp_weights, 1), + "rel_mse": weighted_rel_mse / max(target_weights, 1), + "target_bpw": 4.0, + } + if self.h.is_main_process: + log( + f"codebook:fit target_tensors:{len(fit_items)} target_weights:{target_weights} " + f"coverage:{self.last_fit_summary['coverage']:.4%} rel_mse:{self.last_fit_summary['rel_mse']:.6e} " + f"target_bpw:{self.last_fit_summary['target_bpw']:.4f}" + ) + + def build_export( + self, + state_dict: dict[str, Tensor], + ) -> tuple[dict[str, Tensor], dict[str, object], dict[str, object]]: + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + export_stats: dict[str, object] = {"tensors": {}} + total_payload_bytes = 0 + target_payload_bytes = 0 + int8_fallback_payload_bytes = 0 + int8_fallback_weights = 0 + passthrough_payload_bytes = 0 + passthrough_weights = 0 + target_weights = 0 + total_fp_weights = sum(int(t.numel()) for _, t in state_dict.items() if t.is_floating_point()) + + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + if name in self.states: + state = self.states[name] + idx = state["fixed_idx"] + scales = state["scales"] + result[name + ".idx"] = idx + result[name + ".scale"] = scales + meta[name] = { + "type": "codebook_e8p_fp16_scale", + "shape": list(state["shape"]), + "block_dim": int(state["block_dim"]), + 
"hadamard": bool(self.h.codebook_use_hadamard), + "lattice_scale": float(self.h.codebook_lattice_scale), + } + payload_bytes = ( + idx.numel() * idx.element_size() + + scales.numel() * scales.element_size() + ) + total_payload_bytes += payload_bytes + target_payload_bytes += payload_bytes + target_weights += int(state["stats"]["num_weights"]) + export_stats["tensors"][name] = { + **state["stats"], + "payload_bytes": payload_bytes, + "bpw": (8.0 * payload_bytes) / max(float(state["stats"]["num_weights"]), 1.0), + } + continue + + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + payload_bytes = result[name].numel() * result[name].element_size() + total_payload_bytes += payload_bytes + passthrough_payload_bytes += payload_bytes + if t.is_floating_point(): + passthrough_weights += int(t.numel()) + continue + + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + payload_bytes = result[name].numel() * result[name].element_size() + total_payload_bytes += payload_bytes + passthrough_payload_bytes += payload_bytes + passthrough_weights += int(t.numel()) + continue + + q, s = quantize_float_tensor(t) + result[name + ".q"] = q.cpu().contiguous() + result[name + ".scale"] = s.cpu().contiguous() + meta[name] = {"type": "int8"} + payload_bytes = ( + result[name + ".q"].numel() * result[name + ".q"].element_size() + + result[name + ".scale"].numel() * result[name + ".scale"].element_size() + ) + total_payload_bytes += payload_bytes + int8_fallback_payload_bytes += payload_bytes + int8_fallback_weights += int(t.numel()) + + export_stats["summary"] = { + "target_tensors": len(self.states), + "target_weights": target_weights, + "coverage": target_weights / max(total_fp_weights, 1), + "target_payload_bytes": target_payload_bytes, + "target_bpw": (8.0 * target_payload_bytes) / max(target_weights, 1), + 
"int8_fallback_weights": int8_fallback_weights, + "int8_fallback_payload_bytes": int8_fallback_payload_bytes, + "passthrough_weights": passthrough_weights, + "passthrough_payload_bytes": passthrough_payload_bytes, + "payload_bytes_before_torchsave": total_payload_bytes, + "effective_payload_bpw_all_weights": (8.0 * total_payload_bytes) / max(total_fp_weights, 1), + "model_fp_weights": total_fp_weights, + } + return result, meta, export_stats + + +def dequantize_state_dict_codebook( + result: dict[str, Tensor], + meta: dict[str, object], + template_sd: dict[str, Tensor], +) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + if info in ("passthrough", "passthrough_ctrl"): + t = result[name] + if t.dtype == torch.float16 and orig.dtype in (torch.float32, torch.bfloat16): + t = t.to(orig.dtype) + out[name] = t + continue + if not isinstance(info, dict): + raise ValueError(f"Unsupported compression metadata for {name}: {info!r}") + if info.get("type") == "int8": + out[name] = dequantize_int8_rows( + result[name + ".q"], + result[name + ".scale"], + device=torch.device("cpu"), + dtype=orig.dtype, + ) + continue + if info.get("type") != "codebook_e8p_fp16_scale": + raise ValueError(f"Unsupported compression metadata for {name}: {info!r}") + shape = tuple(int(x) for x in info["shape"]) + block_dim = int(info["block_dim"]) + idx = result[name + ".idx"].to(dtype=torch.int64) + scales = result[name + ".scale"].to(dtype=torch.float32).view(-1, 1) + fixed_blocks = _decode_e8p_blocks( + idx, + float(info["lattice_scale"]), + device=torch.device("cpu"), + dtype=torch.float32, + ) + rotated_blocks = fixed_blocks * scales + sign_vec = _hadamard_sign_vector(name, block_dim, device=torch.device("cpu"), dtype=torch.float32) + blocks = hadamard_unrotate_blocks( + rotated_blocks, + sign_vec, + enabled=bool(info.get("hadamard", True)), + ) + out[name] = _unblockify_weight(blocks, 
shape).to(orig.dtype) + return out + + +_BSHF_MAGIC = b"BSHF" + + +def _byte_shuffle(data: bytes, stride: int = 2) -> bytes: + """Transpose byte stream by stride position for better compression.""" + if stride <= 1 or len(data) < stride: + return data + src = np.frombuffer(data, dtype=np.uint8) + n = len(src) + out = np.empty(n, dtype=np.uint8) + dest_off = 0 + for pos in range(stride): + chunk = src[pos::stride] + out[dest_off:dest_off + len(chunk)] = chunk + dest_off += len(chunk) + return _BSHF_MAGIC + bytes([stride]) + out.tobytes() + + +def _byte_unshuffle(data: bytes) -> bytes: + """Inverse of _byte_shuffle. Auto-detects BSHF magic header.""" + if len(data) < 5 or data[:4] != _BSHF_MAGIC: + return data + stride = data[4] + if stride < 2: + return data[5:] + payload = np.frombuffer(data, dtype=np.uint8, offset=5) + n = len(payload) + out = np.empty(n, dtype=np.uint8) + src_off = 0 + for pos in range(stride): + chunk_len = n // stride + (1 if pos < n % stride else 0) + out[pos::stride][:chunk_len] = payload[src_off:src_off + chunk_len] + src_off += chunk_len + return out.tobytes() + + +def _compress(data: bytes, compressor: str, byte_shuffle: bool = True) -> bytes: + if byte_shuffle: + data = _byte_shuffle(data) + if compressor == "lzma": + return lzma.compress(data, preset=6) + elif compressor == "brotli": + import brotli + return brotli.compress(data, quality=11) + raise ValueError(f"Unknown compressor: {compressor!r}") + + +def _decompress(data: bytes, compressor: str, byte_shuffle: bool = True) -> bytes: + if compressor == "lzma": + raw = lzma.decompress(data) + elif compressor == "brotli": + import brotli + raw = brotli.decompress(data) + else: + raise ValueError(f"Unknown compressor: {compressor!r}") + if byte_shuffle: + raw = _byte_unshuffle(raw) + return raw + + +def serialize(h: Hyperparameters, base_model: torch.nn.Module, code: str) -> int: + quantizer = CodebookQuantizer(h, base_model) + device = next(base_model.parameters()).device + code_bytes = 
len(code.encode("utf-8")) + bytes_total = code_bytes + if h.is_main_process: + torch.save(base_model.state_dict(), h.model_path) + model_bytes = os.path.getsize(h.model_path) + log(f"Serialized model: {model_bytes} bytes") + log(f"Code size: {code_bytes} bytes") + + hessians: dict[str, Tensor] = {} + if quantizer.target_names: + if h.is_main_process: + log("codebook:collecting calibration hessians...") + t0 = time.perf_counter() + calib_loader = DistributedTokenLoader(h.train_files, h.rank, h.world_size, device) + hessians = collect_hessians( + base_model, + calib_loader, + h, + device, + quantizer.target_names, + n_calibration_batches=h.codebook_calibration_batches, + ) + if h.is_main_process: + log(f"codebook:collected {len(hessians)} hessians in {time.perf_counter() - t0:.1f}s") + + if h.is_main_process: + t0 = time.perf_counter() + quantizer.fit(base_model, hessians) + log(f"codebook:fit_time:{time.perf_counter() - t0:.1f}s") + quant_result, quant_meta, quant_stats = quantizer.build_export(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta, "s": quant_stats}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = _compress(quant_raw, h.compressor) + quant_file_bytes = len(quant_blob) + bytes_total = quant_file_bytes + code_bytes + with open(h.quantized_model_path, "wb") as f: + f.write(quant_blob) + summary = quant_stats["summary"] + log( + f"Serialized model codebook+{h.compressor}: {quant_file_bytes} bytes " + f"(payload_before_torchsave:{summary['payload_bytes_before_torchsave']} bytes)" + ) + log( + f"Codebook coverage:{summary['coverage']:.4%} target_bpw:{summary['target_bpw']:.4f} " + f"effective_payload_bpw_all_weights:{summary['effective_payload_bpw_all_weights']:.4f}" + ) + log( + f"Fallback payloads int8_weights:{summary['int8_fallback_weights']} " + f"int8_bytes:{summary['int8_fallback_payload_bytes']} " + f"passthrough_weights:{summary['passthrough_weights']} " + 
f"passthrough_bytes:{summary['passthrough_payload_bytes']}" + ) + log(f"Total submission size codebook+{h.compressor}: {bytes_total} bytes") + return bytes_total + + +def deserialize(h: Hyperparameters, device: torch.device) -> GPT: + eval_model = GPT(h).to(device).bfloat16() + restore_fp32_params(eval_model) + + template_sd = eval_model.state_dict() + + with open(h.quantized_model_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(_decompress(quant_blob_disk, h.compressor)), + map_location="cpu", + ) + deq_state = dequantize_state_dict_codebook(quant_state["w"], quant_state["m"], template_sd) + eval_model.load_state_dict(deq_state, strict=True) + + return eval_model + +# ---------------------------------------- +# Evaluation +# ---------------------------------------- + +def _loss_bpb(loss_sum, token_count, byte_count) -> tuple[float, float]: + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + return val_loss, val_bpb + + +def eval_val( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + model: nn.Module +) -> tuple[float, float]: + seq_len = h.eval_seq_len + local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_TOKENS must provide at least one sequence per rank; " + f"got VAL_BATCH_TOKENS={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, " + f"GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * h.rank) // h.world_size + seq_end = (total_seqs * (h.rank + 1)) // h.world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval()
+ with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_data.val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = val_data.base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (val_data.has_leading_space_lut[tgt_ids] & ~val_data.is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + model.train() + return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) + + +def eval_val_sliding( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + base_model: nn.Module, + batch_seqs: int = 32 +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + base_model.eval() + logits_fn = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + seq_len = h.eval_seq_len + context_size = seq_len - h.eval_stride + total_tokens = val_data.val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, h.eval_stride) + if ws + context_size < total_tokens] + + total_windows = len(window_starts) + my_s = (total_windows * h.rank) // h.world_size + my_e = (total_windows * 
(h.rank + 1)) // h.world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + we = min(ws + seq_len, total_tokens) + wlen = we - ws + wlens.append(wlen) + chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = logits_fn(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else context_size + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + base_model.train() + return _loss_bpb(loss_sum, token_count, byte_count) + + +def timed_eval(label: str, fn, *args, **kwargs) -> tuple[float, float]: + torch.cuda.synchronize() + t0 = time.perf_counter() + val_loss, val_bpb = fn(*args, **kwargs) + 
torch.cuda.synchronize() + elapsed_ms = 1000.0 * (time.perf_counter() - t0) + log(f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms") + return val_loss, val_bpb + + +def run_evals( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + eval_model: torch.nn.Module +): + compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True) + timed_eval("final_codebook_roundtrip", eval_val, h, device, val_data, compiled_model) + if h.sliding_window_enabled: + timed_eval("final_codebook_sliding_window", eval_val_sliding, h, device, val_data, eval_model) + +# ---------------------------------------- +# Training +# ---------------------------------------- + +def train_model(h: Hyperparameters, device: torch.device, val_data: ValidationData) -> tuple[nn.Module, nn.Module]: + # Set up model + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + if h.distributed: + model = DDP(compiled_model, device_ids=[h.local_rank], broadcast_buffers=False) + else: + model = compiled_model + log(f"model_params:{sum(p.numel() for p in base_model.parameters())}") + + # Set up optimizer and load train data + optimizers = Optimizers(h, base_model) + train_loader = DistributedTokenLoader(h.train_files, h.rank, h.world_size, device) + + # Helper functions for training + max_wallclock_ms = 1000.0 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + if max_wallclock_ms is not None and h.codebook_reserve_seconds > 0: + max_wallclock_ms = max(max_wallclock_ms - h.codebook_reserve_seconds * 1000.0, 0.0) + log(f"codebook:reserving {h.codebook_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms") + + def training_frac(step: int, elapsed_ms: float) -> float: + """Fraction of training completed (0 to 1), using step or wallclock.""" + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-9) + + def
lr_mul(frac: float) -> float: + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + def step_fn(step, lr_scale): + optimizers.zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + if h.distributed: + model.require_backward_grad_sync = micro_step == h.grad_accum_steps - 1 + x, y = train_loader.next_batch(h.train_batch_tokens, h.train_seq_len, h.grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + + frac = min(step / h.muon_momentum_warmup_steps, 1.0) if h.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * h.muon_momentum_warmup_start + frac * h.muon_momentum + for group in optimizers.optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * lr_scale + + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm) + + optimizers.step() + return train_loss + + # Model warmup + if h.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if warmup_step <= 5 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == h.warmup_steps: + log(f"warmup_step: {warmup_step + 1}/{h.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + optimizers.zero_grad_all() + if h.distributed: + model.require_backward_grad_sync = True + 
train_loader = DistributedTokenLoader( + h.train_files, h.rank, h.world_size, device) + + # Training loop + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = h.ema_decay + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == h.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (h.val_loss_every > 0 and step % h.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(h, device, val_data, model) + log(f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}") + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < h.iterations: + log( + f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms " + f"step: {step}/{h.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + scale = lr_mul(frac) + train_loss = step_fn(step, scale) + + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + should_log_train = ( + h.train_log_every > 0 + and (step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1000.0) + log( + f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} " + f"train_time: {approx_training_time_ms / 60000:.1f}m tok/s: {tok_per_sec:.0f}" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if h.distributed 
and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Weight averaging + log("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + + return base_model, compiled_model + + +def train_and_eval(h: Hyperparameters, device: torch.device) -> None: + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + + val_data = ValidationData(h, device) + log(f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}") + log(f"val_tokens: {val_data.val_tokens.numel() - 1}") + + base_model, compiled_model = train_model(h, device, val_data) + timed_eval("pre-codebook post-ema", eval_val, h, device, val_data, compiled_model) + + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + eval_model = deserialize(h, device) + + run_evals(h, device, val_data, eval_model) + + +def main(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + + device = torch.device("cuda", local_rank) + 
torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + os.makedirs("logs", exist_ok=True) + log(100 * "=", console=False) + log("Hyperparameters:", console=True) + for k, v in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}", console=True) + log(Path(__file__).read_text(encoding="utf-8"), console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running PyTorch {torch.__version__}", console=False) + log( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log("=" * 100, console=False) + + train_and_eval(h, device) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-04-07_Codebooks/vq_codebook_13_1024.log b/records/track_non_record_16mb/2026-04-07_Codebooks/vq_codebook_13_1024.log new file mode 100644 index 0000000000..87d3c5c14e --- /dev/null +++ b/records/track_non_record_16mb/2026-04-07_Codebooks/vq_codebook_13_1024.log @@ -0,0 +1,2998 @@ +==================================================================================================== +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + codebook_assignment_entropy_weight: 0.0 + codebook_block_dim: 8 + codebook_calibration_batches: 64 + codebook_entropy_focus: + codebook_entropy_summary_topk: 12 
+ codebook_hessian_damp: 0.01 + codebook_lattice_scale: 1.03 + codebook_outlier_frac: 0.0 + codebook_outlier_max_count: 2048 + codebook_penalty_enabled: True + codebook_penalty_refresh_every: 64 + codebook_penalty_start_frac: 0.9 + codebook_penalty_weight: 4.0 + codebook_reserve_seconds: 30.0 + codebook_scale_bits: 8 + codebook_soft_snap_alpha: 0.02 + codebook_soft_snap_enabled: False + codebook_soft_snap_start_frac: 0.9 + codebook_soft_snap_update_every: 1 + codebook_use_hadamard: True + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp4096 + distributed: True + ema_decay: 0.99 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + eval_seq_len: 2048 + eval_stride: 64 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/vq_codebook_13_2048_1024.txt + logit_softcap: 30.0 + matrix_lr: 0.02 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_wd: 0.085 + num_heads: 8 + num_kv_heads: 4 + num_layers: 13 + qk_gain_init: 4.0 + quantized_model_path: final_model.codebook.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: vq_codebook_13_2048_1024 + scalar_lr: 0.02 + seed: 1024 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_4096_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp4096/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp4096/fineweb_val_*.bin + val_loss_every: 4000 + ve_dim: 128 + ve_enabled: True + ve_layers: 9,10 + vocab_size: 4096 + warmdown_frac: 0.667 + 
warmup_steps: 20 + world_size: 8 + xsa_last_n: 13 +import copy +import glob +import io +import lzma +import math +import os +from pathlib import Path +import random +import subprocess +import sys +import time +import uuid + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +from torch import Tensor, nn + +try: + from flash_attn_interface import flash_attn_func as _flash_attn_3_func +except ImportError: + _flash_attn_3_func = None + +def _use_flash_attn() -> bool: + if _flash_attn_3_func is None: + return False + v = os.environ.get("USE_FLASH_ATTN", "1").strip().lower() + return v not in ("0", "false", "no", "off") + +# ---------------------------------------- +# Hyperparameters +# ---------------------------------------- + +class Hyperparameters(): + # Experiment settings + data_dir = os.environ.get('DATA_DIR', './data/') + seed = int(os.environ.get('SEED', 1337)) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + + # Training length + iterations = int(os.environ.get('ITERATIONS', 20000)) + warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', 0.667)) + warmup_steps = int(os.environ.get('WARMUP_STEPS', 20)) + train_batch_tokens = int(os.environ.get('TRAIN_BATCH_TOKENS', 2048 * 48 * 8)) + train_seq_len = int(os.environ.get('TRAIN_SEQ_LEN', 2048)) + eval_seq_len = int(os.environ.get('EVAL_SEQ_LEN', 2048)) + max_wallclock_seconds = float(os.environ.get('MAX_WALLCLOCK_SECONDS', 600.0)) + train_log_every = int(os.environ.get('TRAIN_LOG_EVERY', 500)) + + # Validation/Evals + val_batch_tokens = int(os.environ.get('VAL_BATCH_TOKENS', 2048 * 32 * 8)) + val_loss_every = int(os.environ.get('VAL_LOSS_EVERY', 4000)) + sliding_window_enabled = bool(int(os.environ.get('SLIDING_WINDOW_ENABLED', '1'))) + + # Model architecture + vocab_size = int(os.environ.get('VOCAB_SIZE', 4096)) + num_layers = int(os.environ.get('NUM_LAYERS', 11)) + 
xsa_last_n = int(os.environ.get('XSA_LAST_N', 11)) + num_kv_heads = int(os.environ.get('NUM_KV_HEADS', 4)) + model_dim = int(os.environ.get('MODEL_DIM', 512)) + embedding_dim = int(os.environ.get('EMBEDDING_DIM', 512)) + num_heads = int(os.environ.get('NUM_HEADS', 8)) + mlp_mult = float(os.environ.get('MLP_MULT', 4.0)) + skip_gates_enabled = bool(int(os.environ.get('SKIP_GATES_ENABLED', '1'))) + tie_embeddings = bool(int(os.environ.get('TIE_EMBEDDINGS', '1'))) + logit_softcap = float(os.environ.get('LOGIT_SOFTCAP', 30.0)) + rope_base = float(os.environ.get('ROPE_BASE', 10000.0)) + rope_dims = int(os.environ.get('ROPE_DIMS', 16)) + rope_train_seq_len = int(os.environ.get('ROPE_TRAIN_SEQ_LEN', 2048)) + ln_scale = bool(int(os.environ.get('LN_SCALE', '1'))) + ve_enabled = bool(int(os.environ.get('VE_ENABLED', '1'))) + ve_dim = int(os.environ.get('VE_DIM', 128)) + ve_layers = os.environ.get('VE_LAYERS', '9,10') + qk_gain_init = float(os.environ.get('QK_GAIN_INIT', 4.0)) + + # Optimizer + min_lr = float(os.environ.get('MIN_LR', 0.0)) + embed_lr = float(os.environ.get('EMBED_LR', 0.6)) + head_lr = float(os.environ.get('HEAD_LR', 0.008)) + tied_embed_lr = float(os.environ.get('TIED_EMBED_LR', 0.03)) + tied_embed_init_std = float(os.environ.get('TIED_EMBED_INIT_STD', 0.005)) + matrix_lr = float(os.environ.get('MATRIX_LR', 0.02)) + scalar_lr = float(os.environ.get('SCALAR_LR', 0.02)) + muon_momentum = float(os.environ.get('MUON_MOMENTUM', 0.99)) + muon_backend_steps = int(os.environ.get('MUON_BACKEND_STEPS', 5)) + muon_momentum_warmup_start = float(os.environ.get('MUON_MOMENTUM_WARMUP_START', 0.92)) + muon_momentum_warmup_steps = int(os.environ.get('MUON_MOMENTUM_WARMUP_STEPS', 1500)) + beta1 = float(os.environ.get('BETA1', 0.9)) + beta2 = float(os.environ.get('BETA2', 0.95)) + adam_eps = float(os.environ.get('ADAM_EPS', 1e-8)) + grad_clip_norm = float(os.environ.get('GRAD_CLIP_NORM', 0.3)) + eval_stride = int(os.environ.get('EVAL_STRIDE', 64)) + muon_beta2 = 
float(os.environ.get('MUON_BETA2', 0.95)) + adam_wd = float(os.environ.get('ADAM_WD', 0.02)) + muon_wd = float(os.environ.get('MUON_WD', 0.085)) + embed_wd = float(os.environ.get('EMBED_WD', 0.085)) + ema_decay = float(os.environ.get('EMA_DECAY', 1.0)) + + # Compression + compressor = os.environ.get('COMPRESSOR', 'brotli') #(lzma or brotli) + codebook_block_dim = int(os.environ.get("CODEBOOK_BLOCK_DIM", 8)) + codebook_use_hadamard = bool(int(os.environ.get("CODEBOOK_USE_HADAMARD", "1"))) + codebook_calibration_batches = int(os.environ.get("CODEBOOK_CALIBRATION_BATCHES", 64)) + codebook_reserve_seconds = float(os.environ.get("CODEBOOK_RESERVE_SECONDS", 10.0)) + codebook_hessian_damp = float(os.environ.get("CODEBOOK_HESSIAN_DAMP", 0.01)) + codebook_lattice_scale = float(os.environ.get("CODEBOOK_LATTICE_SCALE", 1.03)) + codebook_scale_bits = int(os.environ.get("CODEBOOK_SCALE_BITS", 8)) + codebook_assignment_entropy_weight = float(os.environ.get("CODEBOOK_ASSIGNMENT_ENTROPY_WEIGHT", 0.01)) + codebook_soft_snap_enabled = bool(int(os.environ.get("CODEBOOK_SOFT_SNAP_ENABLED", "0"))) + codebook_soft_snap_start_frac = float(os.environ.get("CODEBOOK_SOFT_SNAP_START_FRAC", 0.9)) + codebook_soft_snap_alpha = float(os.environ.get("CODEBOOK_SOFT_SNAP_ALPHA", 0.02)) + codebook_soft_snap_update_every = int(os.environ.get("CODEBOOK_SOFT_SNAP_UPDATE_EVERY", 1)) + codebook_penalty_enabled = bool( + int(os.environ.get("CODEBOOK_PENALTY_ENABLED", os.environ.get("CODEBOOK_STALE_PENALTY_ENABLED", "0"))) + ) + codebook_penalty_start_frac = float( + os.environ.get("CODEBOOK_PENALTY_START_FRAC", os.environ.get("CODEBOOK_STALE_PENALTY_START_FRAC", 0.75)) + ) + codebook_penalty_weight = float( + os.environ.get("CODEBOOK_PENALTY_WEIGHT", os.environ.get("CODEBOOK_STALE_PENALTY_WEIGHT", 0.01)) + ) + codebook_penalty_refresh_every = int( + os.environ.get("CODEBOOK_PENALTY_REFRESH_EVERY", os.environ.get("CODEBOOK_STALE_PENALTY_REFRESH_EVERY", 16)) + ) + codebook_outlier_frac = 
float(os.environ.get("CODEBOOK_OUTLIER_FRAC", 0.0)) + codebook_outlier_max_count = int(os.environ.get("CODEBOOK_OUTLIER_MAX_COUNT", 0)) + codebook_entropy_summary_topk = int(os.environ.get("CODEBOOK_ENTROPY_SUMMARY_TOPK", 12)) + codebook_entropy_focus = os.environ.get("CODEBOOK_ENTROPY_FOCUS", "").strip() + + # Distributed setup + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + + # Data paths + datasets_dir = os.path.join(data_dir, 'datasets', f'fineweb10B_sp{vocab_size}') + train_files = os.path.join(datasets_dir, 'fineweb_train_*.bin') + val_files = os.path.join(datasets_dir, 'fineweb_val_*.bin') + tokenizer_path = os.path.join(data_dir, 'tokenizers', f'fineweb_{vocab_size}_bpe.model') + + # Experiment files + logfile = f"logs/{run_id}.txt" + model_path = "final_model.pt" + quantized_model_path = "final_model.codebook.ptz" + +# ---------------------------------------- +# Global Logging Function +# ---------------------------------------- + +_logger_hparams = None + + +def set_logging_hparams(h: Hyperparameters) -> None: + global _logger_hparams + _logger_hparams = h + + +def log(msg, console: bool = True) -> None: + if _logger_hparams is None: + print(msg) + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + +# ---------------------------------------- +# Data Loading +# ---------------------------------------- + +class ValidationData: + def __init__(self, h: Hyperparameters, device: torch.device): + if not h.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {h.tokenizer_path}") + self.sp = 
spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError( + f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}" + ) + + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + self.base_bytes_lut, self.has_leading_space_lut, self.is_boundary_token_lut = ( + build_sentencepiece_luts(self.sp, h.vocab_size, device)) + + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + # The BPB calculation assumes "▁" is its own token so that leading-space bytes + # are counted correctly. See https://github.com/openai/parameter-golf/issues/897 + assert sp.piece_to_id("\u2581") != sp.unk_id(), \ + "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting" + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files 
found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" int: + key = str(file) + cached = _SHARD_NTOKENS_CACHE.get(key) + if cached is not None: + return cached + header = np.fromfile(file, dtype=" np.memmap: + key = str(file) + mm = _MMAP_CACHE.get(key) + if mm is not None: + return mm + n = _read_num_tokens(file) + mm = np.memmap(file, mode="r", dtype=" int: + if n <= 1: + return 1 + while True: + s = int(self._rng.integers(1, n)) + if math.gcd(s, n) == 1: + return s + + def _reset_cursor(self, si: int, seq_len: int) -> None: + nt = int(self._num_tokens[si]) + max_phase = min(seq_len - 1, max(0, nt - seq_len - 1)) + phase = int(self._rng.integers(max_phase + 1)) if max_phase > 0 else 0 + bc = (nt - 1 - phase) // seq_len + self._cursor_phase[si] = phase + self._cursor_block_count[si] = bc + self._cursor_next[si] = 0 + self._cursor_start[si] = int(self._rng.integers(bc)) if bc > 1 else 0 + self._cursor_stride[si] = self._pick_coprime_stride(bc) + self._cursor_init[si] = True + + def _ensure_cursor(self, si: int, seq_len: int) -> None: + if not self._cursor_init[si] or self._cursor_next[si] >= self._cursor_block_count[si]: + self._reset_cursor(si, seq_len) + + def _take_from_shard(self, si: int, seq_len: int, count: int, out: list[tuple[int, int]]) -> None: + rem = count + while rem > 0: + self._ensure_cursor(si, seq_len) + bc = int(self._cursor_block_count[si]) + ni = int(self._cursor_next[si]) + take = min(rem, bc - ni) + phase = int(self._cursor_phase[si]) + start = int(self._cursor_start[si]) + stride = int(self._cursor_stride[si]) + for j in range(take): + bi = 
(start + (ni + j) * stride) % bc + out.append((si, phase + bi * seq_len)) + self._cursor_next[si] = ni + take + rem -= take + + def _init_pipeline(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> None: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + num_seqs = local_tokens // seq_len + global_num_seqs = num_seqs * self.world_size + self._cfg = (local_tokens, seq_len, num_seqs, global_num_seqs) + bbc = (self._num_tokens - 1) // seq_len + eligible = bbc > 0 + self._eligible_shards = np.nonzero(eligible)[0].astype(np.int64) + self._base_block_counts = bbc[self._eligible_shards].astype(np.int64) + + def _sample_global_windows(self) -> list[tuple[int, int]]: + assert self._cfg is not None and self._eligible_shards is not None + _, seq_len, _, gns = self._cfg + ec = int(self._eligible_shards.size) + progress = min(self._batches_built / 1800.0, 1.0) + remaining = np.empty(ec, dtype=np.float64) + for i, si in enumerate(self._eligible_shards.tolist()): + if self._cursor_init[si]: + r = int(self._cursor_block_count[si]) - int(self._cursor_next[si]) + remaining[i] = float(max(r, 1)) + else: + remaining[i] = float(self._base_block_counts[i]) + alpha = 0.90 - 0.40 * progress + weights = np.power(remaining, alpha) + ws = float(weights.sum()) + if not np.isfinite(ws) or ws <= 0.0: + weights = np.ones(ec, dtype=np.float64) + ws = float(weights.sum()) + probs = weights / ws + low = min(max(8, self.world_size), ec, gns) + high = min(max(32, self.world_size * 8), ec, gns) + mix = max(1, min(int(round(low + progress * (high - low))), ec, gns)) + cp = self._rng.choice(ec, size=mix, replace=False, p=probs) + cs = self._eligible_shards[cp] + cpr = probs[cp].copy() + cpr /= cpr.sum() + counts = np.ones(mix, dtype=np.int64) + extra = gns - mix + if extra > 0: + counts += self._rng.multinomial(extra, cpr).astype(np.int64) + perm = self._rng.permutation(mix) + cs, counts = cs[perm], counts[perm] + buckets: list[list[tuple[int, int]]] = [] + for si, 
cnt in zip(cs.tolist(), counts.tolist()): + b: list[tuple[int, int]] = [] + self._take_from_shard(int(si), seq_len, int(cnt), b) + if b: + if len(b) > 1: + bp = self._rng.permutation(len(b)) + b = [b[int(k)] for k in bp.tolist()] + buckets.append(b) + windows: list[tuple[int, int]] = [] + active = [i for i, bk in enumerate(buckets) if bk] + while active: + order = self._rng.permutation(len(active)) + new_active: list[int] = [] + for oi in order.tolist(): + bi = active[oi] + if buckets[bi]: + windows.append(buckets[bi].pop()) + if buckets[bi]: + new_active.append(bi) + active = new_active + return windows + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + if self._cfg is None: + self._init_pipeline(global_tokens, seq_len, grad_accum_steps) + _, _, num_seqs, _ = self._cfg + gw = self._sample_global_windows() + local_w = gw[self.rank::self.world_size] + x = torch.empty((num_seqs, seq_len), dtype=torch.int64) + y = torch.empty((num_seqs, seq_len), dtype=torch.int64) + for slot, (si, pos) in enumerate(local_w): + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor(np.array(mm[pos:pos + seq_len + 1], dtype=np.int64)) + x[slot] = window[:-1] + y[slot] = window[1:] + self._batches_built += 1 + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ---------------------------------------- +# Model Architecture +# ---------------------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, 
rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, + rope_base: float, qk_gain_init: float, train_seq_len: 
int): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len) + self.use_xsa = False + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if _use_flash_attn(): + y = _flash_attn_3_func(q, k, v, causal=True) + else: + y = F.scaled_dot_product_attention(q, k, v, is_causal=True) + if 
self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + + +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square()) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, + rope_base: float, qk_gain_init: float, train_seq_len: int, + layer_idx: int = 0, ln_scale: bool = False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + 
mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + return x_out + + +class GPT(nn.Module): + def __init__(self, h: Hyperparameters): + super().__init__() + self._ve_target_dim = h.num_kv_heads * (h.model_dim // h.num_heads) + if h.logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.tok_emb = nn.Embedding(h.vocab_size, h.embedding_dim) + if h.embedding_dim != h.model_dim: + self.embed_proj = CastedLinear(h.embedding_dim, h.model_dim, bias=False) + self.head_proj = CastedLinear(h.model_dim, h.embedding_dim, bias=False) + else: + self.embed_proj = None + self.head_proj = None + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32)) + self.skip_gates = nn.Parameter(torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32)) if h.skip_gates_enabled else None + self.blocks = nn.ModuleList([ + Block(h.model_dim, h.num_heads, h.num_kv_heads, h.mlp_mult, h.rope_base, + h.qk_gain_init, h.train_seq_len, layer_idx=i, ln_scale=h.ln_scale) + for i in range(h.num_layers) + ]) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary(head_dim, base=h.rope_base, train_seq_len=h.train_seq_len, rope_dims=h.rope_dims) + self.ve_layer_indices = [int(x) for x in h.ve_layers.split(",") if x.strip()] 
if h.ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(h.vocab_size, h.ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if h.tie_embeddings else CastedLinear(h.embedding_dim, h.vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.embed_proj is not None: + x = self.embed_proj(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, 
ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + scaled_skip = self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + if self.head_proj is not None: + x = self.head_proj(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + return F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1), reduction="mean") + +# ---------------------------------------- +# Optimization +# ---------------------------------------- + +@torch.compile +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + 
world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +class Optimizers(): + def __init__(self, h: Hyperparameters, base_model: GPT): + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in + CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in + CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + + token_lr = h.tied_embed_lr if h.tie_embeddings 
else h.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + + self.optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.embed_wd, + fused=True, + ) + self.optimizer_muon = Muon( + matrix_params, + lr=h.matrix_lr, + momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, + weight_decay=h.muon_wd, + ) + for group in self.optimizer_muon.param_groups: + group["base_lr"] = h.matrix_lr + self.optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.adam_wd, + fused=True, + ) + self.optimizers: list[torch.optim.Optimizer] = [self.optimizer_tok, self.optimizer_muon, self.optimizer_scalar] + if base_model.lm_head is not None: + self.optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": h.head_lr, "base_lr": h.head_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + fused=True, + ) + self.optimizers.insert(1, self.optimizer_head) + else: + self.optimizer_head = None + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self) -> None: + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def step(self): + for opt in self.optimizers: + opt.step() + self.zero_grad_all() + +# ---------------------------------------- +# Quantization +# ---------------------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + 
"attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +CODEBOOK_TARGET_PATTERNS = ( + "attn.c_q.weight", + "attn.c_k.weight", + "attn.c_v.weight", + "attn.proj.weight", + "mlp.fc.weight", + "mlp.proj.weight", +) + + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def dequantize_int8_rows(q: Tensor, s: Tensor, *, device: torch.device, dtype: torch.dtype) -> Tensor: + qf = q.to(device=device, dtype=torch.float32) + sf = s.to(device=device, dtype=torch.float32) + if sf.ndim > 0: + return (qf * sf.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype) + return (qf * float(sf.item())).to(dtype=dtype) + + +def restore_fp32_params(model: nn.Module) -> None: + """After .bfloat16(), restore CastedLinear weights and control params to FP32.""" + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern 
in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +def _is_power_of_two(n: int) -> bool: + return n > 0 and (n & (n - 1)) == 0 + + +def _validate_codebook_hparams(h: Hyperparameters) -> None: + if not _is_power_of_two(h.codebook_block_dim): + raise ValueError(f"CODEBOOK_BLOCK_DIM must be a power of 2, got {h.codebook_block_dim}") + if h.codebook_block_dim != 8: + raise ValueError(f"The E8P codebook expects CODEBOOK_BLOCK_DIM=8, got {h.codebook_block_dim}") + if h.codebook_calibration_batches < 0: + raise ValueError( + f"CODEBOOK_CALIBRATION_BATCHES must be non-negative, got {h.codebook_calibration_batches}" + ) + if h.codebook_reserve_seconds < 0: + raise ValueError( + f"CODEBOOK_RESERVE_SECONDS must be non-negative, got {h.codebook_reserve_seconds}" + ) + if h.codebook_hessian_damp < 0: + raise ValueError( + f"CODEBOOK_HESSIAN_DAMP must be non-negative, got {h.codebook_hessian_damp}" + ) + if h.codebook_lattice_scale <= 0: + raise ValueError(f"CODEBOOK_LATTICE_SCALE must be positive, got {h.codebook_lattice_scale}") + if h.codebook_scale_bits <= 0 or h.codebook_scale_bits > 16: + raise ValueError(f"CODEBOOK_SCALE_BITS must be in [1, 16], got {h.codebook_scale_bits}") + if h.codebook_assignment_entropy_weight < 0: + raise ValueError( + f"CODEBOOK_ASSIGNMENT_ENTROPY_WEIGHT must be non-negative, got {h.codebook_assignment_entropy_weight}" + ) + if not 0.0 <= h.codebook_soft_snap_start_frac <= 1.0: + raise ValueError( + f"CODEBOOK_SOFT_SNAP_START_FRAC must be in [0, 1], got {h.codebook_soft_snap_start_frac}" + ) + if not 0.0 <= h.codebook_soft_snap_alpha <= 1.0: + raise ValueError(f"CODEBOOK_SOFT_SNAP_ALPHA must be in [0, 1], got {h.codebook_soft_snap_alpha}") + if h.codebook_soft_snap_update_every <= 0: + raise ValueError( + f"CODEBOOK_SOFT_SNAP_UPDATE_EVERY must be positive, got {h.codebook_soft_snap_update_every}" + ) + if not 0.0 <= h.codebook_penalty_start_frac <= 1.0: + raise ValueError( + 
f"CODEBOOK_PENALTY_START_FRAC must be in [0, 1], got {h.codebook_penalty_start_frac}" + ) + if h.codebook_penalty_weight < 0: + raise ValueError( + f"CODEBOOK_PENALTY_WEIGHT must be non-negative, got {h.codebook_penalty_weight}" + ) + if h.codebook_penalty_refresh_every <= 0: + raise ValueError( + f"CODEBOOK_PENALTY_REFRESH_EVERY must be positive, got {h.codebook_penalty_refresh_every}" + ) + if not 0.0 <= h.codebook_outlier_frac <= 1.0: + raise ValueError(f"CODEBOOK_OUTLIER_FRAC must be in [0, 1], got {h.codebook_outlier_frac}") + if h.codebook_outlier_max_count < 0: + raise ValueError(f"CODEBOOK_OUTLIER_MAX_COUNT must be non-negative, got {h.codebook_outlier_max_count}") + + +def _blockify_weight(t: Tensor, block_dim: int) -> tuple[Tensor, tuple[int, int]]: + if t.ndim != 2: + raise ValueError(f"Codebook quantization expects a 2D tensor, got shape {tuple(t.shape)}") + if t.shape[1] % block_dim != 0: + raise ValueError( + f"Tensor shape {tuple(t.shape)} is not compatible with CODEBOOK_BLOCK_DIM={block_dim}" + ) + return t.contiguous().view(-1, block_dim), (int(t.shape[0]), int(t.shape[1])) + + +def _unblockify_weight(blocks: Tensor, original_shape: tuple[int, int]) -> Tensor: + return blocks.contiguous().view(original_shape) + + +def _should_codebook_quantize(name: str, t: Tensor, h: Hyperparameters) -> bool: + return ( + t.is_floating_point() + and t.ndim == 2 + and t.shape[1] % h.codebook_block_dim == 0 + and name.startswith("blocks.") + and any(name.endswith(pattern) for pattern in CODEBOOK_TARGET_PATTERNS) + ) +def _hadamard_sign_vector(name: str, block_dim: int, *, device: torch.device, dtype: torch.dtype) -> Tensor: + rng = random.Random(f"hadamard::{name}::{block_dim}") + return torch.tensor( + [1.0 if rng.getrandbits(1) else -1.0 for _ in range(block_dim)], + device=device, + dtype=dtype, + ) + + +def hadamard_transform(x: Tensor, scale: float = 1.0) -> Tensor: + original_shape = x.shape + dim = x.shape[-1] + padded_dim = 1 << (max(dim, 1) - 
1).bit_length() + out = x.reshape(-1, dim) + if padded_dim != dim: + out = F.pad(out, (0, padded_dim - dim)) + h = 1 + while h < padded_dim: + out = out.view(-1, padded_dim // (2 * h), 2, h) + a = out[:, :, 0, :] + b = out[:, :, 1, :] + out = torch.stack((a + b, a - b), dim=2).reshape(-1, padded_dim) + h *= 2 + out = out[:, :dim] + return (out * scale).reshape(*original_shape) + + +def hadamard_rotate_blocks(blocks: Tensor, sign_vec: Tensor, *, enabled: bool = True) -> Tensor: + if not enabled: + return blocks + return hadamard_transform(blocks * sign_vec, scale=blocks.shape[-1] ** -0.5) + + +def hadamard_unrotate_blocks(blocks: Tensor, sign_vec: Tensor, *, enabled: bool = True) -> Tensor: + if not enabled: + return blocks + return hadamard_transform(blocks, scale=blocks.shape[-1] ** -0.5) * sign_vec + + +def get_norm12() -> Tensor: + return torch.tensor([ + [3, 1, 1, 1, 3, 3, 3, 3], + [1, 3, 1, 1, 3, 3, 3, 3], + [1, 1, 3, 1, 3, 3, 3, 3], + [1, 1, 1, 3, 3, 3, 3, 3], + [3, 3, 3, 1, 3, 3, 1, 1], + [3, 3, 3, 1, 3, 1, 3, 1], + [3, 3, 3, 1, 1, 3, 3, 1], + [3, 3, 3, 1, 3, 1, 1, 3], + [3, 3, 3, 1, 1, 3, 1, 3], + [3, 3, 3, 1, 1, 1, 3, 3], + [3, 3, 1, 3, 3, 3, 1, 1], + [3, 3, 1, 3, 3, 1, 3, 1], + [3, 3, 1, 3, 1, 3, 3, 1], + [3, 3, 1, 3, 3, 1, 1, 3], + [3, 3, 1, 3, 1, 3, 1, 3], + [3, 3, 1, 3, 1, 1, 3, 3], + [3, 1, 3, 3, 3, 3, 1, 1], + [3, 1, 3, 3, 3, 1, 3, 1], + [3, 1, 3, 3, 1, 3, 3, 1], + [3, 1, 3, 3, 3, 1, 1, 3], + [3, 1, 3, 3, 1, 3, 1, 3], + [1, 3, 3, 3, 1, 1, 3, 3], + [1, 3, 3, 3, 3, 3, 1, 1], + [1, 3, 3, 3, 3, 1, 3, 1], + [1, 3, 3, 3, 1, 3, 3, 1], + [1, 3, 3, 3, 3, 1, 1, 3], + [1, 3, 3, 3, 1, 3, 1, 3], + [1, 1, 3, 3, 1, 3, 3, 3], + [3, 3, 1, 1, 3, 3, 3, 1], + ], dtype=torch.float32) / 2 + + +def get_packed_abs_grid() -> Tensor: + intr = torch.arange(-4, 4) + d8 = torch.cartesian_prod(*[intr] * 8).float() + 0.5 + d8_even = d8.sum(dim=-1) % 2 == 0 + d8_small = d8.norm(dim=-1).square() <= 10 + d8_abs = torch.unique(d8[torch.where(d8_even & d8_small)[0]].abs(), dim=0) + 
codebook_abs = torch.cat([d8_abs, get_norm12()], dim=0) + codebook_abs = codebook_abs[:, [0, 2, 4, 6, 1, 3, 5, 7]] + codebook_abs[:, 7] *= (1 - 2 * (codebook_abs.sum(1) % 2)) + codebook_abs = (codebook_abs * 2 + 8).to(torch.int32) + acc = codebook_abs[:, 0] + for i in range(7): + acc = acc | (codebook_abs[:, i + 1] << ((i + 1) * 4)) + return acc + + +def get_abs_grid() -> Tensor: + intr = torch.arange(-4, 4) + d8 = torch.cartesian_prod(*[intr] * 8).float() + 0.5 + d8_even = d8.sum(dim=-1) % 2 == 0 + d8_small = d8.norm(dim=-1).square() <= 10 + d8_abs = torch.unique(d8[torch.where(d8_even & d8_small)[0]].abs(), dim=0) + return torch.cat([d8_abs, get_norm12()], dim=0) + + +def get_full_grid(packed_abs_grid: Tensor) -> tuple[Tensor, Tensor]: + codebook = torch.zeros(1 << 16, 8) + parity_idx: list[int] = [] + shuffle_map = [0, 4, 1, 5, 2, 6, 3, 7] + for code in range(1 << 16): + signs = code & 255 + abs_idx = code >> 8 + parity = 0 + for bit in range(8): + parity ^= (signs >> bit) & 1 + signs ^= parity + abs_code = packed_abs_grid[abs_idx].item() + for i, shuffled in enumerate(shuffle_map): + codebook[code, i] = (((abs_code >> (4 * shuffled)) & 15) - 8) * 0.5 + if (signs >> shuffled) & 1: + codebook[code, i] *= -1 + if parity: + codebook[code] -= 0.25 + parity_idx.append(code) + else: + codebook[code] += 0.25 + return codebook, torch.tensor(parity_idx, dtype=torch.long) + + +_E8P_PACKED_ABS = get_packed_abs_grid() +_E8P_GRID, _E8P_PARITY_IDX = get_full_grid(_E8P_PACKED_ABS) + + +class E8P12Codebook(nn.Module): + def __init__(self) -> None: + super().__init__() + self.codesz = 8 + self.register_buffer("grid", _E8P_GRID) + self.register_buffer("grid_norm", _E8P_GRID.norm(dim=-1).square()) + grid_part = _E8P_GRID[_E8P_PARITY_IDX] + 0.25 + grid_part = grid_part[ + torch.where( + ((grid_part[:, :7] < 0).sum(dim=-1) <= 1) + & (grid_part[:, :7].min(dim=-1).values >= -0.5) + )[0] + ] + abs_grid = get_abs_grid() + self.register_buffer("grid_part", grid_part) + 
self.register_buffer("grid_part_norm", grid_part.norm(dim=-1).square()) + self.register_buffer("grid_abs_odd", abs_grid.sum(dim=-1) % 2 == 1) + self.register_buffer("part_abs_map", self.round(grid_part.abs(), abs_grid, abs_grid.norm(dim=-1).square())[1]) + self.register_buffer("bit_map", 2 ** torch.arange(8)) + + def round(self, x: Tensor, grid: Tensor, grid_norm: Tensor) -> tuple[Tensor, Tensor]: + idx = (2 * x @ grid.t() - grid_norm).argmax(dim=-1) + return grid[idx], idx + + def round_topk(self, x: Tensor, grid: Tensor, grid_norm: Tensor, k: int = 2) -> tuple[Tensor, Tensor]: + scores = 2 * x @ grid.t() - grid_norm + top_idx = scores.topk(k, dim=-1).indices + return grid[top_idx], top_idx + + def fast_quantize_part(self, x: Tensor, parity: bool) -> tuple[Tensor, Tensor, Tensor]: + x_part = torch.abs(x) + x_odd = torch.where((x < 0).sum(dim=-1) % 2 != 0)[0] + x_part[x_odd, 7] = -x_part[x_odd, 7] + mask = 1 - 2 * (x < 0).to(torch.float32) + mask[x_odd, 7] = -mask[x_odd, 7] + rounded, rounded_idx = self.round(x_part, self.grid_part, self.grid_part_norm) + vals = rounded * mask + err = (x - vals).norm(dim=-1) + abs_idx = self.part_abs_map[rounded_idx] + sign_mask = ((rounded < 0) ^ (mask < 0))[:, [0, 2, 4, 6, 1, 3, 5, 7]] + sign_mask[:, 7] ^= self.grid_abs_odd[abs_idx] + sign_mask[:, 0] ^= parity + idx = (abs_idx << 8) + (sign_mask * self.bit_map).sum(dim=-1).int() + return vals, idx, err + + def fast_quantize_part_topk(self, x: Tensor, parity: bool, k: int = 2) -> tuple[Tensor, Tensor, Tensor]: + x_part = torch.abs(x) + x_odd = torch.where((x < 0).sum(dim=-1) % 2 != 0)[0] + x_part[x_odd, 7] = -x_part[x_odd, 7] + mask = 1 - 2 * (x < 0).to(torch.float32) + mask[x_odd, 7] = -mask[x_odd, 7] + rounded, rounded_idx = self.round_topk(x_part, self.grid_part, self.grid_part_norm, k=k) + vals = rounded * mask.unsqueeze(1) + err = (x.unsqueeze(1) - vals).square().sum(dim=-1) + abs_idx = self.part_abs_map[rounded_idx] + sign_mask = ((rounded < 0) ^ (mask.unsqueeze(1) < 0))[:, 
:, [0, 2, 4, 6, 1, 3, 5, 7]]
        sign_mask[:, :, 7] ^= self.grid_abs_odd[abs_idx]
        sign_mask[:, :, 0] ^= parity
        idx = (abs_idx << 8) + (sign_mask * self.bit_map.view(1, 1, -1)).sum(dim=-1).int()
        return vals, idx, err

    def quantize(self, x: Tensor) -> tuple[Tensor, Tensor]:
        plus_vals, plus_idx, plus_err = self.fast_quantize_part(x + 0.25, True)
        minus_vals, minus_idx, minus_err = self.fast_quantize_part(x - 0.25, False)
        use_plus = plus_err < minus_err
        vals = torch.where(use_plus.unsqueeze(-1), plus_vals - 0.25, minus_vals + 0.25)
        idx = torch.where(use_plus, plus_idx, minus_idx)
        return vals, idx

    def quantize_top2(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        plus_vals, plus_idx, plus_err = self.fast_quantize_part_topk(x + 0.25, True, k=2)
        minus_vals, minus_idx, minus_err = self.fast_quantize_part_topk(x - 0.25, False, k=2)
        cand_vals = torch.cat((plus_vals - 0.25, minus_vals + 0.25), dim=1)
        cand_idx = torch.cat((plus_idx, minus_idx), dim=1)
        cand_err = torch.cat((plus_err, minus_err), dim=1)
        order = cand_err.topk(2, dim=1, largest=False).indices
        gather_vals = cand_vals.gather(1, order.unsqueeze(-1).expand(-1, -1, self.codesz))
        gather_idx = cand_idx.gather(1, order)
        gather_err = cand_err.gather(1, order)
        return gather_vals[:, 0], gather_idx[:, 0], gather_vals[:, 1], gather_err[:, 1] - gather_err[:, 0]

    def decode(self, idxs: Tensor) -> Tensor:
        idxs_long = idxs.long().reshape(-1)
        vals = self.grid.index_select(0, idxs_long)
        return vals.view(*idxs.shape, self.codesz)


_E8P_CODEBOOK_CACHE: dict[str, E8P12Codebook] = {}


def _e8p_cache_key(device: torch.device | None) -> str:
    if device is None:
        return "cpu"
    return f"{device.type}:{device.index}"


def _get_e8p_codebook(device: torch.device | None) -> E8P12Codebook:
    key = _e8p_cache_key(device)
    cached = _E8P_CODEBOOK_CACHE.get(key)
    if cached is not None:
        return cached
    codebook = E8P12Codebook()
    if device is not None:
        codebook = codebook.to(device=device, dtype=torch.float32)
    else:
        codebook = codebook.to(dtype=torch.float32)
    codebook.eval()
    _E8P_CODEBOOK_CACHE[key] = codebook
    return codebook


@torch.no_grad()
def _quantize_e8p_blocks(blocks: Tensor, lattice_scale: float) -> tuple[Tensor, Tensor]:
    codebook = _get_e8p_codebook(blocks.device)
    scaled = blocks.to(dtype=torch.float32) * float(lattice_scale)
    vals, idxs = codebook.quantize(scaled)
    return vals / float(lattice_scale), idxs.long()


@torch.no_grad()
def _decode_e8p_blocks(
    idxs: Tensor,
    lattice_scale: float,
    *,
    device: torch.device,
    dtype: torch.dtype,
) -> Tensor:
    codebook = _get_e8p_codebook(device)
    vals = codebook.decode(idxs.to(device=device))
    return (vals / float(lattice_scale)).to(dtype=dtype)


@torch.no_grad()
def _fast_quantize_weight_target(h: Hyperparameters, name: str, weight: Tensor) -> Tensor:
    weight_f32 = weight.detach().float()
    blocks, shape = _blockify_weight(weight_f32, h.codebook_block_dim)
    sign_vec = _hadamard_sign_vector(name, h.codebook_block_dim, device=weight.device, dtype=torch.float32)
    rotated_blocks = hadamard_rotate_blocks(blocks, sign_vec, enabled=h.codebook_use_hadamard)
    block_norms = rotated_blocks.norm(dim=-1, keepdim=True).clamp_min(1e-8)
    codebook = _get_e8p_codebook(weight.device)
    target_radius = float(codebook.grid.norm(dim=-1).median().item())
    proxy_blocks = rotated_blocks / block_norms * target_radius
    quantized_proxy, _ = _quantize_e8p_blocks(proxy_blocks, h.codebook_lattice_scale)
    quantized_dirs = F.normalize(quantized_proxy.to(dtype=torch.float32), dim=-1)
    recon_rotated = quantized_dirs * block_norms
    recon_blocks = hadamard_unrotate_blocks(
        recon_rotated,
        sign_vec,
        enabled=h.codebook_use_hadamard,
    )
    return _unblockify_weight(recon_blocks, shape).to(dtype=weight.dtype)


def _log_prior_from_assignments(
    assignments: Tensor,
    num_entries: int,
    *,
    smoothing: float = 1.0,
) -> Tensor:
    flat = assignments.reshape(-1).to(dtype=torch.int64)
    counts = torch.bincount(flat, minlength=num_entries).to(dtype=torch.float32)
    probs = (counts + float(smoothing)) / float(counts.sum().item() + smoothing * num_entries)
    return torch.log(probs)


@torch.no_grad()
def _metric_optimal_codewords(
    rotated_blocks: Tensor,
    metrics: Tensor,
    candidates: Tensor,
    *,
    candidate_batch_size: int = 4096,
    assignment_log_prior: Tensor | None = None,
    assignment_entropy_weight: float = 0.0,
) -> tuple[Tensor, Tensor]:
    if rotated_blocks.ndim != 3:
        raise ValueError(f"Expected rotated blocks with shape [rows, positions, dim], got {tuple(rotated_blocks.shape)}")
    expected_metrics_shape = (rotated_blocks.shape[1], rotated_blocks.shape[2], rotated_blocks.shape[2])
    if metrics.ndim != 3 or metrics.shape != expected_metrics_shape:
        raise ValueError(f"Expected metrics with shape {expected_metrics_shape}, got {tuple(metrics.shape)}")
    if candidates.ndim != 2 or candidates.shape[1] != rotated_blocks.shape[2]:
        raise ValueError(f"Expected candidates [entries, {rotated_blocks.shape[2]}], got {tuple(candidates.shape)}")
    if assignment_log_prior is not None and assignment_log_prior.shape != (candidates.shape[0],):
        raise ValueError(
            f"Expected assignment_log_prior with shape ({candidates.shape[0]},), got {tuple(assignment_log_prior.shape)}"
        )
    num_rows, num_positions, _ = rotated_blocks.shape
    selected_idx = torch.empty((num_rows, num_positions), device=rotated_blocks.device, dtype=torch.long)
    selected_codewords = torch.empty_like(rotated_blocks, dtype=torch.float32)
    for pos in range(num_positions):
        metric = metrics[pos]
        x = rotated_blocks[:, pos, :].to(dtype=torch.float32)
        x_metric = torch.matmul(x, metric)
        best_improvement = torch.full((num_rows,), -torch.inf, device=x.device, dtype=torch.float32)
        best_idx = torch.zeros((num_rows,), device=x.device, dtype=torch.long)
        for start in range(0, candidates.shape[0], candidate_batch_size):
            cand = candidates[start:start + candidate_batch_size]
            cand_metric = torch.matmul(cand, metric)
            numer = torch.matmul(x_metric, cand.t())
            denom = (cand_metric * cand).sum(dim=-1).clamp_min(1e-8)
            # x^T M x is candidate-independent, so minimizing the optimal metric error is
            # equivalent to maximizing the gain from the best non-negative scale.
            improvement = numer.clamp_min(0.0).square() / denom.unsqueeze(0)
            if assignment_log_prior is not None and assignment_entropy_weight > 0:
                improvement = improvement + float(assignment_entropy_weight) * assignment_log_prior[
                    start:start + cand.shape[0]
                ].unsqueeze(0)
            chunk_best_improvement, chunk_best_idx = improvement.max(dim=1)
            update = chunk_best_improvement > best_improvement
            if update.any():
                best_improvement = torch.where(update, chunk_best_improvement, best_improvement)
                best_idx = torch.where(update, start + chunk_best_idx, best_idx)
        selected_idx[:, pos] = best_idx
        selected_codewords[:, pos, :] = candidates.index_select(0, best_idx)
    return selected_codewords, selected_idx


def collect_hessians(
    model: nn.Module,
    train_loader: DistributedTokenLoader,
    h: Hyperparameters,
    device: torch.device,
    target_names: set[str],
    n_calibration_batches: int = 64,
) -> dict[str, Tensor]:
    """Run calibration batches and collect H = X^T X for selected CastedLinear layers."""
    if n_calibration_batches <= 0 or not target_names:
        return {}
    hessians: dict[str, Tensor] = {}
    hooks = []
    was_training = model.training

    def make_hook(name: str):
        def hook_fn(module, inp, out):
            x = inp[0].detach().float()
            if x.ndim == 3:
                x = x.reshape(-1, x.shape[-1])
            if name not in hessians:
                hessians[name] = torch.zeros(
                    x.shape[1], x.shape[1], dtype=torch.float32, device=device
                )
            hessians[name].addmm_(x.T, x)
        return hook_fn

    for module_name, module in model.named_modules():
        if not isinstance(module, CastedLinear):
            continue
        weight_name = f"{module_name}.weight"
        if weight_name in target_names:
            hooks.append(module.register_forward_hook(make_hook(weight_name)))

    model.eval()
    with torch.no_grad():
        for _ in range(n_calibration_batches):
            x, _ = train_loader.next_batch(
                h.train_batch_tokens,
                h.train_seq_len,
                h.grad_accum_steps,
            )
            model.forward_logits(x)

    for hook in hooks:
        hook.remove()

    world = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
    for name, hessian in hessians.items():
        if world > 1:
            dist.all_reduce(hessian, op=dist.ReduceOp.SUM)
        hessians[name] = hessian.cpu() / float(max(world * n_calibration_batches, 1))

    if was_training:
        model.train()
    return hessians


def _identity_block_metrics(shape: tuple[int, int], block_dim: int, device: torch.device) -> Tensor:
    num_positions = shape[1] // block_dim
    eye = torch.eye(block_dim, device=device, dtype=torch.float32)
    return eye.unsqueeze(0).repeat(num_positions, 1, 1)


def _block_metrics_from_hessian(
    shape: tuple[int, int],
    block_dim: int,
    sign_vec: Tensor,
    hessian: Tensor | None,
    *,
    use_hadamard: bool,
    damp_factor: float,
    device: torch.device,
) -> Tensor:
    if hessian is None or hessian.ndim != 2 or hessian.shape != (shape[1], shape[1]):
        return _identity_block_metrics(shape, block_dim, device)
    num_positions = shape[1] // block_dim
    hessian_work = hessian.to(device=device, dtype=torch.float32).clone()
    diag = hessian_work.diag()
    dead = diag <= 0
    if dead.any():
        hessian_work[dead, dead] = 1.0
    damp = float(damp_factor * hessian_work.diag().mean().clamp_min(1e-8).item())
    hessian_work.diagonal().add_(damp)
    reshaped = hessian_work.view(num_positions, block_dim, num_positions, block_dim)
    block_idx = torch.arange(num_positions, device=device)
    block_metrics = reshaped[block_idx, :, block_idx, :]
    if use_hadamard:
        unrotate = hadamard_unrotate_blocks(
            torch.eye(block_dim, device=device, dtype=torch.float32),
            sign_vec,
            enabled=True,
        )
        block_metrics = torch.matmul(
            unrotate.unsqueeze(0),
            torch.matmul(block_metrics, unrotate.t().unsqueeze(0)),
        )
    block_metrics = 0.5 * (block_metrics + block_metrics.transpose(-1, -2))
    block_metrics.diagonal(dim1=-2, dim2=-1).add_(1e-6)
    return block_metrics


def _metric_optimal_scales(rotated_blocks: Tensor, codebook_blocks: Tensor, metrics: Tensor) -> Tensor:
    numer = torch.einsum("npd,pde,npe->np", codebook_blocks, metrics, rotated_blocks)
    denom = torch.einsum("npd,pde,npe->np", codebook_blocks, metrics, codebook_blocks).clamp_min(1e-8)
    scales = (numer / denom).clamp_min(1e-8)
    max_fp16 = float(torch.finfo(torch.float16).max)
    return scales.clamp(max=max_fp16).unsqueeze(-1)


@torch.no_grad()
def _metric_optimal_e8p_codewords(
    rotated_blocks: Tensor,
    metrics: Tensor,
    lattice_scale: float,
    *,
    candidate_batch_size: int = 4096,
    assignment_log_prior: Tensor | None = None,
    assignment_entropy_weight: float = 0.0,
) -> tuple[Tensor, Tensor]:
    codebook = _get_e8p_codebook(rotated_blocks.device)
    candidates = codebook.grid.to(dtype=torch.float32) / float(lattice_scale)
    return _metric_optimal_codewords(
        rotated_blocks,
        metrics,
        candidates,
        candidate_batch_size=candidate_batch_size,
        assignment_log_prior=assignment_log_prior,
        assignment_entropy_weight=assignment_entropy_weight,
    )


class FixedCodebookSoftSnap:
    def __init__(self, h: Hyperparameters, model: nn.Module):
        _validate_codebook_hparams(h)
        self.h = h
        self.target_modules: list[tuple[str, CastedLinear]] = []
        self.target_weights = 0
        self.active = False
        self.last_snap_step: int | None = None
        for module_name, module in model.named_modules():
            if not isinstance(module, CastedLinear):
                continue
            weight_name = f"{module_name}.weight"
            if _should_codebook_quantize(weight_name, module.weight, h):
                self.target_modules.append((weight_name, module))
                self.target_weights += int(module.weight.numel())

    def should_activate(self, frac: float) -> bool:
        return self.h.codebook_soft_snap_enabled and not self.active and frac >= self.h.codebook_soft_snap_start_frac

    @torch.no_grad()
    def apply(self, step: int, *, force: bool = False) -> bool:
        if not self.active or not self.target_modules:
            return False
        if not force and self.last_snap_step is not None:
            if (step - self.last_snap_step) < self.h.codebook_soft_snap_update_every:
                return False
        alpha = float(self.h.codebook_soft_snap_alpha)
        if alpha <= 0.0:
            return False
        for name, module in self.target_modules:
            q_weight = _fast_quantize_weight_target(self.h, name, module.weight)
            module.weight.lerp_(q_weight, alpha)
        self.last_snap_step = step
        return True

    def activate(self, step: int, frac: float) -> None:
        if self.active:
            return
        self.active = True
        if self.h.is_main_process:
            log(
                f"codebook_soft_snap:enabled step:{step}/{self.h.iterations} "
                f"start_frac:{self.h.codebook_soft_snap_start_frac:.3f} frac:{frac:.3f} "
                f"mode:fast_direction "
                f"alpha:{self.h.codebook_soft_snap_alpha:.4f} "
                f"target_tensors:{len(self.target_modules)} target_weights:{self.target_weights} "
                f"update_every:{self.h.codebook_soft_snap_update_every}"
            )


class FixedCodebookPenalty:
    def __init__(self, h: Hyperparameters, model: nn.Module):
        _validate_codebook_hparams(h)
        self.h = h
        self.target_modules: list[tuple[str, CastedLinear]] = []
        self.target_weights = 0
        self.active = False
        self.last_refresh_step: int | None = None
        self.cached_targets: list[dict[str, object]] = []
        for module_name, module in model.named_modules():
            if not isinstance(module, CastedLinear):
                continue
            weight_name = f"{module_name}.weight"
            if _should_codebook_quantize(weight_name, module.weight, h):
                self.target_modules.append((weight_name, module))
                self.target_weights += int(module.weight.numel())

    def should_activate(self, frac: float) -> bool:
        return self.h.codebook_penalty_enabled and not self.active and frac >= self.h.codebook_penalty_start_frac

    def activate(self, step: int, frac: float) -> None:
        if self.active:
            return
        self.active = True
        if self.h.is_main_process:
            log(
                f"codebook_penalty:enabled step:{step}/{self.h.iterations} "
                f"start_frac:{self.h.codebook_penalty_start_frac:.3f} frac:{frac:.3f} "
                f"weight:{self.h.codebook_penalty_weight:.5f} "
                f"mode:fast_direction_target "
                f"refresh_every:{self.h.codebook_penalty_refresh_every} "
                f"target_tensors:{len(self.target_modules)} target_weights:{self.target_weights}"
            )

    def _current_weight(self, frac: float) -> float:
        if frac <= self.h.codebook_penalty_start_frac:
            return 0.0
        return float(self.h.codebook_penalty_weight)

    @torch.no_grad()
    def _refresh_targets(self, step: int) -> None:
        refreshed: list[dict[str, object]] = []
        for name, module in self.target_modules:
            q_weight = _fast_quantize_weight_target(self.h, name, module.weight)
            refreshed.append(
                {
                    "name": name,
                    "module": module,
                    "target_weight": q_weight.detach(),
                }
            )
        self.cached_targets = refreshed
        self.last_refresh_step = step

    def penalty(self, step: int, frac: float) -> Tensor | None:
        if not self.active:
            return None
        weight = self._current_weight(frac)
        if weight <= 0.0:
            return None
        if self.last_refresh_step is None or (step - self.last_refresh_step) >= self.h.codebook_penalty_refresh_every:
            self._refresh_targets(step)
        if not self.cached_targets:
            return None
        losses: list[Tensor] = []
        for item in self.cached_targets:
            module = item["module"]
            target_weight = item["target_weight"]
            diff = module.weight.float() - target_weight.float()
            mse = diff.square().mean()
            denom = target_weight.float().square().mean().clamp_min(1e-8)
            losses.append(mse / denom)
        if not losses:
            return None
        return float(weight) * torch.stack(losses).mean()


class CodebookQuantizer:
    def __init__(self, h: Hyperparameters, model: nn.Module):
        _validate_codebook_hparams(h)
        self.h = h
        self.states: dict[str, dict[str, object]] = {}
        self.target_modules: list[tuple[str, CastedLinear]] = []
        self.target_names: set[str] = set()
        self.last_fit_summary: dict[str, float] = {}
        for module_name, module in model.named_modules():
            if not isinstance(module, CastedLinear):
                continue
            weight_name = f"{module_name}.weight"
            if _should_codebook_quantize(weight_name, module.weight, h):
                self.target_modules.append((weight_name, module))
                self.target_names.add(weight_name)

    @torch.no_grad()
    def _prepare_rotated_blocks(
        self,
        name: str,
        module: CastedLinear,
    ) -> tuple[Tensor, tuple[int, int], Tensor]:
        weight = module.weight.detach().float()
        blocks, shape = _blockify_weight(weight, self.h.codebook_block_dim)
        sign_vec = _hadamard_sign_vector(name, self.h.codebook_block_dim, device=blocks.device, dtype=torch.float32)
        rotated_blocks = hadamard_rotate_blocks(blocks, sign_vec, enabled=self.h.codebook_use_hadamard)
        return rotated_blocks, shape, sign_vec

    def _fit_tensor(self, name: str, module: CastedLinear, hessian: Tensor | None) -> tuple[str, dict[str, float]]:
        weight = module.weight.detach().float()
        rotated_blocks, shape, sign_vec = self._prepare_rotated_blocks(name, module)
        blocks, _ = _blockify_weight(weight, self.h.codebook_block_dim)
        num_rows, num_cols = shape
        num_positions = num_cols // self.h.codebook_block_dim
        rotated_grid = rotated_blocks.view(num_rows, num_positions, self.h.codebook_block_dim)
        metrics = _block_metrics_from_hessian(
            shape,
            self.h.codebook_block_dim,
            sign_vec,
            hessian,
            use_hadamard=self.h.codebook_use_hadamard,
            damp_factor=self.h.codebook_hessian_damp,
            device=blocks.device,
        )
        initial_dirs, initial_idx = _metric_optimal_e8p_codewords(
            rotated_grid,
            metrics,
            self.h.codebook_lattice_scale,
        )
        if self.h.codebook_assignment_entropy_weight > 0:
            e8p_entries = int(_get_e8p_codebook(rotated_blocks.device).grid.shape[0])
            assignment_log_prior = _log_prior_from_assignments(initial_idx, e8p_entries)
            fixed_dirs, fixed_idx = _metric_optimal_e8p_codewords(
                rotated_grid,
                metrics,
                self.h.codebook_lattice_scale,
                assignment_log_prior=assignment_log_prior,
                assignment_entropy_weight=self.h.codebook_assignment_entropy_weight,
            )
        else:
            assignment_log_prior = None
            fixed_dirs, fixed_idx = initial_dirs, initial_idx
        scales = _metric_optimal_scales(rotated_grid, fixed_dirs, metrics)
        scales_fp16 = scales.to(dtype=torch.float16)
        recon_rotated = fixed_dirs * scales_fp16.to(dtype=torch.float32)
        recon_blocks = hadamard_unrotate_blocks(
            recon_rotated.view(-1, self.h.codebook_block_dim),
            sign_vec,
            enabled=self.h.codebook_use_hadamard,
        )
        diff = blocks - recon_blocks
        mse = diff.square().mean()
        energy = blocks.square().mean().clamp_min(1e-12)
        stats = {
            "num_weights": float(weight.numel()),
            "mse": float(mse.item()),
            "rel_mse": float((mse / energy).item()),
            "scale_min": float(scales_fp16.min().item()),
            "scale_mean": float(scales_fp16.float().mean().item()),
            "scale_max": float(scales_fp16.max().item()),
        }
        state: dict[str, object] = {
            "shape": shape,
            "block_dim": self.h.codebook_block_dim,
            "fixed_idx": fixed_idx.reshape(-1).to(dtype=torch.uint16).cpu().contiguous(),
            "scales": scales_fp16.reshape(-1).cpu().contiguous(),
            "stats": stats,
        }
        if assignment_log_prior is not None:
            probs = assignment_log_prior.exp()
            stats["assignment_prior_entropy"] = float((-(probs * assignment_log_prior).sum()).item())
        self.states[name] = state
        return name, stats

    @torch.no_grad()
    def fit(self, model: nn.Module, hessians: dict[str, Tensor]) -> None:
        if not self.target_modules:
            if self.h.is_main_process:
                log("codebook:no eligible tensors found")
            return
        total_fp_weights = sum(
            int(t.numel())
            for _, t in model.state_dict().items()
            if t.is_floating_point()
        )
        fit_items = [self._fit_tensor(name, module, hessians.get(name)) for name, module in self.target_modules]
        target_weights = sum(int(stats["num_weights"]) for _, stats in fit_items)
        weighted_rel_mse = sum(float(stats["rel_mse"]) * int(stats["num_weights"]) for _, stats in fit_items)
        self.last_fit_summary = {
            "target_tensors": float(len(fit_items)),
            "target_weights": float(target_weights),
            "coverage": target_weights / max(total_fp_weights, 1),
            "rel_mse": weighted_rel_mse / max(target_weights, 1),
            "target_bpw": 4.0,
        }
        if self.h.is_main_process:
            log(
                f"codebook:fit target_tensors:{len(fit_items)} target_weights:{target_weights} "
                f"coverage:{self.last_fit_summary['coverage']:.4%} rel_mse:{self.last_fit_summary['rel_mse']:.6e} "
                f"target_bpw:{self.last_fit_summary['target_bpw']:.4f}"
            )

    @torch.no_grad()
    def _reconstruct_tensor_from_state(self, name: str, state: dict[str, object], *, dtype: torch.dtype = torch.float32) -> Tensor:
        shape = tuple(int(x) for x in state["shape"])
        block_dim = int(state["block_dim"])
        idx = state["fixed_idx"].to(dtype=torch.int64)
        scales = state["scales"].to(dtype=torch.float32).view(-1, 1)
        fixed_blocks = _decode_e8p_blocks(
            idx,
            float(self.h.codebook_lattice_scale),
            device=torch.device("cpu"),
            dtype=torch.float32,
        )
        rotated_blocks = fixed_blocks * scales
        sign_vec = _hadamard_sign_vector(name, block_dim, device=torch.device("cpu"), dtype=torch.float32)
        blocks = hadamard_unrotate_blocks(
            rotated_blocks,
            sign_vec,
            enabled=bool(self.h.codebook_use_hadamard),
        )
        return _unblockify_weight(blocks, shape).to(dtype)

    @torch.no_grad()
    def _select_outliers(self, original: Tensor, reconstructed: Tensor) -> tuple[Tensor | None, Tensor | None, Tensor | None]:
        requested = 0
        if self.h.codebook_outlier_frac > 0:
            requested = int(math.ceil(float(self.h.codebook_outlier_frac) * float(original.numel())))
        if self.h.codebook_outlier_max_count > 0:
            requested = (
                min(requested, int(self.h.codebook_outlier_max_count))
                if requested > 0
                else int(self.h.codebook_outlier_max_count)
            )
        requested = min(requested, int(original.numel()))
        if requested <= 0:
            return None, None, None
        flat_orig = original.reshape(-1).to(dtype=torch.float32)
        flat_recon = reconstructed.reshape(-1).to(dtype=torch.float32)
        err = (flat_orig - flat_recon).square()
        if requested >= flat_orig.numel():
            idx = torch.arange(flat_orig.numel(), dtype=torch.int32)
        else:
            idx = torch.topk(err, k=requested, largest=True, sorted=False).indices.to(dtype=torch.int64)
            idx = idx.index_select(0, torch.argsort(idx)).to(dtype=torch.int32).cpu()
        vals = flat_orig.index_select(0, idx.to(dtype=torch.int64)).cpu().contiguous()
        max_abs = float(vals.abs().max().item()) if vals.numel() > 0 else 0.0
        if max_abs <= 1e-12:
            scale = torch.ones((1,), dtype=torch.float16)
            q = torch.zeros(vals.numel(), dtype=torch.int8)
        else:
            scale_value = max_abs / 127.0
            q = torch.round(vals / scale_value).clamp_(-127, 127).to(dtype=torch.int8).contiguous()
            scale = torch.tensor([scale_value], dtype=torch.float16)
        return idx.contiguous(), q.cpu().contiguous(), scale.cpu().contiguous()

    def build_export(
        self,
        state_dict: dict[str, Tensor],
    ) -> tuple[dict[str, Tensor], dict[str, object], dict[str, object]]:
        result: dict[str, Tensor] = {}
        meta: dict[str, object] = {}
        export_stats: dict[str, object] = {"tensors": {}}
        total_payload_bytes = 0
        target_payload_bytes = 0
        outlier_payload_bytes = 0
        outlier_weights = 0
        int8_fallback_payload_bytes = 0
        int8_fallback_weights = 0
        passthrough_payload_bytes = 0
        passthrough_weights = 0
        target_weights = 0
        total_fp_weights = sum(int(t.numel()) for _, t in state_dict.items() if t.is_floating_point())

        for name, tensor in state_dict.items():
            t = tensor.detach().cpu().contiguous()
            if name in self.states:
                state = self.states[name]
                idx = state["fixed_idx"]
                scales = state["scales"]
                scale_payload, scale_meta = _quantize_codebook_scales(scales, self.h.codebook_scale_bits)
                result[name + ".idx"] = idx
                result[name + ".scale"] = scale_payload
                tensor_outlier_count = 0
                tensor_outlier_payload_bytes = 0
                if self.h.codebook_outlier_frac > 0 or self.h.codebook_outlier_max_count > 0:
                    reconstructed = self._reconstruct_tensor_from_state(name, state, dtype=torch.float32)
                    outlier_idx, outlier_q, outlier_scale = self._select_outliers(t.float(), reconstructed)
                    if (
                        outlier_idx is not None
                        and outlier_q is not None
                        and outlier_scale is not None
                        and outlier_idx.numel() > 0
                    ):
                        result[name + ".outlier_idx"] = outlier_idx
                        result[name + ".outlier_q"] = outlier_q
                        result[name + ".outlier_scale"] = outlier_scale
                        tensor_outlier_count = int(outlier_idx.numel())
                        tensor_outlier_payload_bytes = (
                            outlier_idx.numel() * outlier_idx.element_size()
                            + outlier_q.numel() * outlier_q.element_size()
                            + outlier_scale.numel() * outlier_scale.element_size()
                        )
                meta[name] = {
                    "type": "codebook_e8p",
                    "shape": list(state["shape"]),
                    "block_dim": int(state["block_dim"]),
                    "hadamard": bool(self.h.codebook_use_hadamard),
                    "lattice_scale": float(self.h.codebook_lattice_scale),
                    "outlier_count": tensor_outlier_count,
                    "outlier_format": "int8" if tensor_outlier_count > 0 else "none",
                    **scale_meta,
                }
                payload_bytes = (
                    idx.numel() * idx.element_size()
                    + scale_payload.numel() * scale_payload.element_size()
                    + tensor_outlier_payload_bytes
                )
                total_payload_bytes += payload_bytes
                target_payload_bytes += payload_bytes
                outlier_payload_bytes += tensor_outlier_payload_bytes
                outlier_weights += tensor_outlier_count
                target_weights += int(state["stats"]["num_weights"])
                export_stats["tensors"][name] = {
                    **state["stats"],
                    "payload_bytes": payload_bytes,
                    "outlier_payload_bytes": tensor_outlier_payload_bytes,
                    "outlier_count": tensor_outlier_count,
                    "scale_bits": int(scale_meta["scale_bits"]),
                    "scale_format": str(scale_meta["scale_format"]),
                    "bpw": (8.0 * payload_bytes) / max(float(state["stats"]["num_weights"]), 1.0),
                }
                export_stats.setdefault("diagnostics", []).append(
                    _codebook_payload_diagnostics(
                        name,
                        idx,
                        scales,
                        scale_payload,
                        tuple(int(x) for x in state["shape"]),
                        int(state["block_dim"]),
                        self.h.compressor,
                    )
                )
                continue

            if not t.is_floating_point() or t.numel() <= 65536:
                result[name] = t.to(torch.float16) if t.is_floating_point() else t
                meta[name] = "passthrough"
                payload_bytes = result[name].numel() * result[name].element_size()
                total_payload_bytes += payload_bytes
                passthrough_payload_bytes += payload_bytes
                if t.is_floating_point():
                    passthrough_weights += int(t.numel())
                continue

            if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS):
                result[name] = t.float()
                meta[name] = "passthrough_ctrl"
                payload_bytes = result[name].numel() * result[name].element_size()
                total_payload_bytes += payload_bytes
                passthrough_payload_bytes += payload_bytes
                passthrough_weights += int(t.numel())
                continue

            q, s = quantize_float_tensor(t)
            result[name + ".q"] = q.cpu().contiguous()
            result[name + ".scale"] = s.cpu().contiguous()
            meta[name] = {"type": "int8"}
            payload_bytes = (
                result[name + ".q"].numel() * result[name + ".q"].element_size()
                + result[name + ".scale"].numel() * result[name + ".scale"].element_size()
            )
            total_payload_bytes += payload_bytes
            int8_fallback_payload_bytes += payload_bytes
            int8_fallback_weights += int(t.numel())

        export_stats["summary"] = {
            "target_tensors": len(self.states),
            "target_weights": target_weights,
            "coverage": target_weights / max(total_fp_weights, 1),
            "target_payload_bytes": target_payload_bytes,
            "target_bpw": (8.0 * target_payload_bytes) / max(target_weights, 1),
            "outlier_weights": outlier_weights,
            "outlier_payload_bytes": outlier_payload_bytes,
            "int8_fallback_weights": int8_fallback_weights,
            "int8_fallback_payload_bytes": int8_fallback_payload_bytes,
            "passthrough_weights": passthrough_weights,
            "passthrough_payload_bytes": passthrough_payload_bytes,
            "payload_bytes_before_torchsave": total_payload_bytes,
            "effective_payload_bpw_all_weights": (8.0 * total_payload_bytes) / max(total_fp_weights, 1),
            "model_fp_weights": total_fp_weights,
        }
        return result, meta, export_stats


def dequantize_state_dict_codebook(
    result: dict[str, Tensor],
    meta: dict[str, object],
    template_sd: dict[str, Tensor],
) -> dict[str, Tensor]:
    out: dict[str, Tensor] = {}
    for name, orig in template_sd.items():
        info = meta.get(name)
        if info is None:
            continue
        if info in ("passthrough", "passthrough_ctrl"):
            t = result[name]
            if t.dtype == torch.float16 and orig.dtype in (torch.float32, torch.bfloat16):
                t = t.to(orig.dtype)
            out[name] = t
            continue
        if not isinstance(info, dict):
            raise ValueError(f"Unsupported compression metadata for {name}: {info!r}")
        if info.get("type") == "int8":
            out[name] = dequantize_int8_rows(
                result[name + ".q"],
                result[name + ".scale"],
                device=torch.device("cpu"),
                dtype=orig.dtype,
            )
            continue
        if info.get("type") not in ("codebook_e8p_fp16_scale", "codebook_e8p"):
            raise ValueError(f"Unsupported compression metadata for {name}: {info!r}")
        shape = tuple(int(x) for x in info["shape"])
        block_dim = int(info["block_dim"])
        idx = result[name + ".idx"].to(dtype=torch.int64)
        if info.get("type") == "codebook_e8p_fp16_scale":
            scales = result[name + ".scale"].to(dtype=torch.float32).view(-1, 1)
        else:
            scales = _dequantize_codebook_scales(result[name + ".scale"], info)
        fixed_blocks = _decode_e8p_blocks(
            idx,
            float(info["lattice_scale"]),
            device=torch.device("cpu"),
            dtype=torch.float32,
        )
        rotated_blocks = fixed_blocks * scales
        sign_vec = _hadamard_sign_vector(name, block_dim, device=torch.device("cpu"), dtype=torch.float32)
        blocks = hadamard_unrotate_blocks(
            rotated_blocks,
            sign_vec,
            enabled=bool(info.get("hadamard", True)),
        )
        restored = _unblockify_weight(blocks, shape).to(orig.dtype)
        if int(info.get("outlier_count", 0)) > 0:
            flat = restored.reshape(-1)
            outlier_idx = result[name + ".outlier_idx"].to(dtype=torch.int64)
            outlier_q = result[name + ".outlier_q"].to(dtype=torch.float32).reshape(-1)
            outlier_scale = float(result[name + ".outlier_scale"].to(dtype=torch.float32).reshape(-1)[0].item())
            outlier_val = (outlier_q * outlier_scale).to(dtype=orig.dtype)
            flat[outlier_idx] = outlier_val
            restored = flat.view_as(restored)
        out[name] = restored
    return out


_BSHF_MAGIC = b"BSHF"


def _byte_shuffle(data: bytes, stride: int = 2) -> bytes:
    """Transpose byte stream by stride position for better compression."""
    if stride <= 1 or len(data) < stride:
        return data
    src = np.frombuffer(data, dtype=np.uint8)
    n = len(src)
    out = np.empty(n, dtype=np.uint8)
    dest_off = 0
    for pos in range(stride):
        chunk = src[pos::stride]
        out[dest_off:dest_off + len(chunk)] = chunk
        dest_off += len(chunk)
    return _BSHF_MAGIC + bytes([stride]) + out.tobytes()


def _byte_unshuffle(data: bytes) -> bytes:
    """Inverse of _byte_shuffle. Auto-detects BSHF magic header."""
    if len(data) < 5 or data[:4] != _BSHF_MAGIC:
        return data
    stride = data[4]
    if stride < 2:
        return data[5:]
    payload = np.frombuffer(data, dtype=np.uint8, offset=5)
    n = len(payload)
    out = np.empty(n, dtype=np.uint8)
    src_off = 0
    for pos in range(stride):
        chunk_len = n // stride + (1 if pos < n % stride else 0)
        out[pos::stride][:chunk_len] = payload[src_off:src_off + chunk_len]
        src_off += chunk_len
    return out.tobytes()


def _compress(data: bytes, compressor: str, byte_shuffle: bool = True) -> bytes:
    if byte_shuffle:
        data = _byte_shuffle(data)
    if compressor == "lzma":
        return lzma.compress(data, preset=6)
    elif compressor == "brotli":
        import brotli
        return brotli.compress(data, quality=11)
    raise ValueError(f"Unknown compressor: {compressor!r}")


def _decompress(data: bytes, compressor: str, byte_shuffle: bool = True) -> bytes:
    if compressor == "lzma":
        raw = lzma.decompress(data)
    elif compressor == "brotli":
        import brotli
        raw = brotli.decompress(data)
    else:
        raise ValueError(f"Unknown compressor: {compressor!r}")
    if byte_shuffle:
        raw = _byte_unshuffle(raw)
    return raw


def _codebook_scale_dtype(bits: int) -> torch.dtype:
    return torch.uint8 if bits <= 8 else torch.uint16


@torch.no_grad()
def _quantize_codebook_scales(scales: Tensor, bits: int) -> tuple[Tensor, dict[str, object]]:
    scales_fp32 = scales.to(dtype=torch.float32).clamp_min(1e-12)
    if bits >= 16:
        return scales.to(dtype=torch.float16).contiguous(), {
            "scale_format": "fp16",
            "scale_bits": 16,
        }

    levels = (1 << bits) - 1
    log_scales = torch.log2(scales_fp32)
    log_min = float(log_scales.min().item())
    log_max = float(log_scales.max().item())
    if log_max <= log_min:
        scale_codes = torch.zeros_like(scales_fp32, dtype=_codebook_scale_dtype(bits))
        log_step = 0.0
    else:
        log_step = (log_max - log_min) / levels
        scale_codes = torch.round((log_scales - log_min) / log_step).clamp_(0, levels).to(_codebook_scale_dtype(bits))
    return scale_codes.contiguous().cpu(), {
        "scale_format": "log",
        "scale_bits": int(bits),
        "scale_log_min": log_min,
        "scale_log_step": float(log_step),
    }


@torch.no_grad()
def _dequantize_codebook_scales(scale_payload: Tensor, info: dict[str, object]) -> Tensor:
    scale_format = str(info.get("scale_format", "fp16"))
    if scale_format == "fp16":
        return scale_payload.to(dtype=torch.float32).view(-1, 1)
    if scale_format != "log":
        raise ValueError(f"Unsupported codebook scale format: {scale_format!r}")
    log_min = float(info["scale_log_min"])
    log_step = float(info["scale_log_step"])
    if log_step == 0.0:
        return torch.full(
            (scale_payload.numel(), 1),
            2.0 ** log_min,
            dtype=torch.float32,
        )
    logs = log_min + scale_payload.to(dtype=torch.float32).reshape(-1, 1) * log_step
    return torch.pow(2.0, logs)


def _entropy_from_counts(counts: np.ndarray) -> float:
    total = int(counts.sum())
    if total <= 0:
        return 0.0
    probs = counts[counts > 0].astype(np.float64) / float(total)
    return float(-(probs * np.log2(probs)).sum())


def _entropy_from_u8_bytes(data: bytes) -> float:
    if not data:
        return 0.0
    arr = np.frombuffer(data, dtype=np.uint8)
    counts = np.bincount(arr, minlength=256)
    return _entropy_from_counts(counts)


def _adjacent_corrcoef(arr: np.ndarray) -> float:
    if arr.size < 2:
        return 1.0
    x0 = arr[:-1].astype(np.float64, copy=False)
    x1 = arr[1:].astype(np.float64, copy=False)
    x0 = x0 - x0.mean()
    x1 = x1 - x1.mean()
    denom = math.sqrt(float(np.dot(x0, x0) * np.dot(x1, x1)))
    if denom <= 0.0:
        return 1.0
    return float(np.dot(x0, x1) / denom)


def _codebook_payload_diagnostics(
    name: str,
    idx: Tensor,
    scales: Tensor,
    scale_payload: Tensor,
    shape: tuple[int, int],
    block_dim: int,
    compressor: str,
) -> dict[str, float | int | str]:
    idx_np = idx.contiguous().numpy().reshape(-1)
    idx_codes = idx_np.astype(np.int64, copy=False)
    scales_np = scales.contiguous().numpy().reshape(-1).astype(np.float16, copy=False)
    scale_payload_np = scale_payload.contiguous().numpy().reshape(-1)
    idx_bytes = idx_np.tobytes()
    scale_bytes = scale_payload_np.tobytes()
    combined_bytes = idx_bytes + scale_bytes

    idx_max = int(idx_codes.max()) if idx_codes.size else 0
    idx_counts = np.bincount(idx_codes, minlength=max(idx_max + 1, 1))
    idx_entropy_bits = _entropy_from_counts(idx_counts)
    idx_unique = int(np.count_nonzero(idx_counts))

    scale_words = scales_np.view(np.uint16)
    _, scale_counts = np.unique(scale_words, return_counts=True)
    scale_value_entropy_bits = _entropy_from_counts(scale_counts)

    scale_grid = scales_np.astype(np.float32, copy=False).reshape(shape[0], shape[1] // block_dim)
    if scale_grid.shape[1] > 1:
        scale_deltas = np.diff(scale_grid, axis=1).astype(np.float16, copy=False).reshape(-1)
        scale_delta_words = scale_deltas.view(np.uint16)
        _, delta_counts = np.unique(scale_delta_words, return_counts=True)
        scale_delta_entropy_bits = _entropy_from_counts(delta_counts)
        scale_delta_mean_abs = float(np.mean(np.abs(scale_deltas.astype(np.float32, copy=False))))
        scale_adjacent_corr = _adjacent_corrcoef(scale_grid.reshape(-1))
        scale_relative_delta = scale_delta_mean_abs / max(float(np.mean(np.abs(scale_grid))), 1e-12)
    else:
        scale_delta_entropy_bits = 0.0
        scale_delta_mean_abs = 0.0
        scale_adjacent_corr = 1.0
        scale_relative_delta = 0.0

    idx_compressed_bytes = len(_compress(idx_bytes, compressor))
    scale_compressed_bytes = len(_compress(scale_bytes, compressor))
    combined_compressed_bytes = len(_compress(combined_bytes, compressor))
    idx_byte_entropy = _entropy_from_u8_bytes(_byte_shuffle(idx_bytes))
    scale_byte_entropy = _entropy_from_u8_bytes(_byte_shuffle(scale_bytes))

    return {
        "name": name,
        "idx_unique": idx_unique,
        "idx_entropy_bits": idx_entropy_bits,
        "idx_byte_entropy_bits": idx_byte_entropy,
        "idx_raw_bytes": len(idx_bytes),
        "idx_compressed_bytes": idx_compressed_bytes,
        "idx_entropy_bound_bytes": (idx_np.size * idx_entropy_bits) / 8.0,
        "scale_unique": int(scale_counts.size),
        "scale_value_entropy_bits": scale_value_entropy_bits,
        "scale_delta_entropy_bits": scale_delta_entropy_bits,
        "scale_byte_entropy_bits": scale_byte_entropy,
        "scale_raw_bytes": len(scale_bytes),
        "scale_compressed_bytes": scale_compressed_bytes,
        "scale_entropy_bound_bytes": (scales_np.size * scale_value_entropy_bits) / 8.0,
        "scale_delta_mean_abs": scale_delta_mean_abs,
        "scale_relative_delta": scale_relative_delta,
        "scale_adjacent_corr": scale_adjacent_corr,
        "scale_storage_bits": float(8.0 * len(scale_bytes) / max(scale_payload_np.size, 1)),
        "combined_raw_bytes": len(combined_bytes),
        "combined_compressed_bytes": combined_compressed_bytes,
    }


def _select_codebook_focus_tensor(
    diagnostics: list[dict[str, float | int | str]],
    focus: str,
) -> dict[str, float | int | str] | None:
    if not diagnostics:
        return None
    if focus:
        for item in diagnostics:
            if item["name"] == focus:
                return item
        for item in diagnostics:
            if focus in str(item["name"]):
                return item
    return max(diagnostics, key=lambda item: float(item["scale_compressed_bytes"]))


def _log_codebook_diagnostics(h: Hyperparameters, quant_stats: dict[str, object]) -> None:
    diagnostics = list(quant_stats.get("diagnostics", []))
    if not diagnostics:
        return
    topk = max(int(h.codebook_entropy_summary_topk), 0)
    total_idx_compressed = sum(int(item["idx_compressed_bytes"]) for item in diagnostics)
    total_scale_compressed = sum(int(item["scale_compressed_bytes"]) for item in diagnostics)
    total_combined_compressed = sum(int(item["combined_compressed_bytes"]) for item in diagnostics)
    log(
        f"codebook:entropy_summary tensors:{len(diagnostics)} "
        f"idx_compressed_bytes:{total_idx_compressed} "
        f"scale_compressed_bytes:{total_scale_compressed} "
        f"combined_compressed_bytes:{total_combined_compressed}"
    )
    if topk > 0:
        ranked = sorted(
            diagnostics,
            key=lambda item: (float(item["scale_compressed_bytes"]), float(item["combined_compressed_bytes"])),
            reverse=True,
        )[:topk]
        for item in ranked:
            log(
                "codebook:tensor "
                f"name:{item['name']} "
                f"idx_raw:{item['idx_raw_bytes']} idx_zip:{item['idx_compressed_bytes']} "
                f"scale_raw:{item['scale_raw_bytes']} scale_zip:{item['scale_compressed_bytes']} "
                f"combined_zip:{item['combined_compressed_bytes']} "
                f"idx_H:{float(item['idx_entropy_bits']):.3f}bits "
                f"scale_H:{float(item['scale_value_entropy_bits']):.3f}bits "
                f"scale_dH:{float(item['scale_delta_entropy_bits']):.3f}bits "
                f"scale_store:{float(item['scale_storage_bits']):.1f}bits "
                f"scale_rel_delta:{float(item['scale_relative_delta']):.5f} "
                f"scale_corr:{float(item['scale_adjacent_corr']):.5f}"
            )
    focus_item = _select_codebook_focus_tensor(diagnostics, h.codebook_entropy_focus)
    if focus_item is None:
        return
    log(
        "codebook:focus "
        f"name:{focus_item['name']} "
f"idx_unique:{focus_item['idx_unique']} " + f"idx_H:{float(focus_item['idx_entropy_bits']):.4f}bits " + f"idx_byte_H:{float(focus_item['idx_byte_entropy_bits']):.4f}bits " + f"idx_raw:{focus_item['idx_raw_bytes']} " + f"idx_entropy_bound:{float(focus_item['idx_entropy_bound_bytes']):.1f} " + f"idx_zip:{focus_item['idx_compressed_bytes']}" + ) + log( + "codebook:focus " + f"name:{focus_item['name']} " + f"scale_unique:{focus_item['scale_unique']} " + f"scale_H:{float(focus_item['scale_value_entropy_bits']):.4f}bits " + f"scale_dH:{float(focus_item['scale_delta_entropy_bits']):.4f}bits " + f"scale_byte_H:{float(focus_item['scale_byte_entropy_bits']):.4f}bits " + f"scale_store:{float(focus_item['scale_storage_bits']):.1f}bits " + f"scale_raw:{focus_item['scale_raw_bytes']} " + f"scale_entropy_bound:{float(focus_item['scale_entropy_bound_bytes']):.1f} " + f"scale_zip:{focus_item['scale_compressed_bytes']}" + ) + log( + "codebook:focus " + f"name:{focus_item['name']} " + f"combined_zip:{focus_item['combined_compressed_bytes']} " + f"scale_delta_mean_abs:{float(focus_item['scale_delta_mean_abs']):.6f} " + f"scale_rel_delta:{float(focus_item['scale_relative_delta']):.6f} " + f"scale_adjacent_corr:{float(focus_item['scale_adjacent_corr']):.6f}" + ) + + +def serialize(h: Hyperparameters, base_model: torch.nn.Module, code: str) -> int: + quantizer = CodebookQuantizer(h, base_model) + device = next(base_model.parameters()).device + code_bytes = len(code.encode("utf-8")) + bytes_total = code_bytes + if h.is_main_process: + torch.save(base_model.state_dict(), h.model_path) + model_bytes = os.path.getsize(h.model_path) + log(f"Serialized model: {model_bytes} bytes") + log(f"Code size: {code_bytes} bytes") + + hessians: dict[str, Tensor] = {} + if quantizer.target_names: + if h.is_main_process: + log("codebook:collecting calibration hessians...") + t0 = time.perf_counter() + calib_loader = DistributedTokenLoader(h.train_files, h.rank, h.world_size, device) + hessians = 
collect_hessians( + base_model, + calib_loader, + h, + device, + quantizer.target_names, + n_calibration_batches=h.codebook_calibration_batches, + ) + if h.is_main_process: + log(f"codebook:collected {len(hessians)} hessians in {time.perf_counter() - t0:.1f}s") + + if h.is_main_process: + t0 = time.perf_counter() + quantizer.fit(base_model, hessians) + log(f"codebook:fit_time:{time.perf_counter() - t0:.1f}s") + quant_result, quant_meta, quant_stats = quantizer.build_export(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta, "s": quant_stats}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = _compress(quant_raw, h.compressor) + quant_file_bytes = len(quant_blob) + bytes_total = quant_file_bytes + code_bytes + with open(h.quantized_model_path, "wb") as f: + f.write(quant_blob) + summary = quant_stats["summary"] + log( + f"Serialized model codebook+{h.compressor}: {quant_file_bytes} bytes " + f"(payload_before_torchsave:{summary['payload_bytes_before_torchsave']} bytes)" + ) + log( + f"Codebook coverage:{summary['coverage']:.4%} target_bpw:{summary['target_bpw']:.4f} " + f"effective_payload_bpw_all_weights:{summary['effective_payload_bpw_all_weights']:.4f}" + ) + log( + f"Codebook scale_codec:{'fp16' if h.codebook_scale_bits >= 16 else 'log'} " + f"scale_bits:{h.codebook_scale_bits} " + f"entropy_weight:{h.codebook_assignment_entropy_weight:.4f}" + ) + log( + f"Codebook outliers frac:{h.codebook_outlier_frac:.6f} " + f"max_count:{h.codebook_outlier_max_count} " + f"outlier_weights:{summary['outlier_weights']} " + f"outlier_bytes:{summary['outlier_payload_bytes']}" + ) + log( + f"Fallback payloads int8_weights:{summary['int8_fallback_weights']} " + f"int8_bytes:{summary['int8_fallback_payload_bytes']} " + f"passthrough_weights:{summary['passthrough_weights']} " + f"passthrough_bytes:{summary['passthrough_payload_bytes']}" + ) + _log_codebook_diagnostics(h, quant_stats) + log(f"Total submission size 
codebook+{h.compressor}: {bytes_total} bytes") + return bytes_total + + + def deserialize(h: Hyperparameters, device: torch.device) -> GPT: + eval_model = GPT(h).to(device).bfloat16() + restore_fp32_params(eval_model) + + template_sd = eval_model.state_dict() + + with open(h.quantized_model_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(_decompress(quant_blob_disk, h.compressor)), + map_location="cpu", + ) + deq_state = dequantize_state_dict_codebook(quant_state["w"], quant_state["m"], template_sd) + eval_model.load_state_dict(deq_state, strict=True) + + return eval_model + + + # ---------------------------------------- + # Evaluation + # ---------------------------------------- + + def _loss_bpb(loss_sum: Tensor, token_count: Tensor, byte_count: Tensor) -> tuple[float, float]: + """Convert summed NLL and token/byte counts into (val_loss, val_bpb).""" + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + return val_loss, val_bpb + + + def eval_val( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + model: nn.Module + ) -> tuple[float, float]: + seq_len = h.eval_seq_len + local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, " + f"GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * h.rank) // h.world_size + seq_end = (total_seqs * (h.rank + 1)) // h.world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + # Avoid leaving inference-mode tensors in module caches that training later reuses. 
+ with torch.no_grad(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_data.val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = val_data.base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (val_data.has_leading_space_lut[tgt_ids] & ~val_data.is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + model.train() + return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) + + +def eval_val_sliding( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + base_model: nn.Module, + batch_seqs: int = 32 +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + base_model.eval() + logits_fn = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + seq_len = h.eval_seq_len + context_size = seq_len - h.eval_stride + total_tokens = val_data.val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, h.eval_stride) + if ws + context_size < total_tokens] + + total_windows = len(window_starts) + my_s = (total_windows * h.rank) // h.world_size + my_e = (total_windows * 
(h.rank + 1)) // h.world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Avoid leaving inference-mode tensors in module caches that training later reuses. + with torch.no_grad(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + we = min(ws + seq_len, total_tokens) + wlen = we - ws + wlens.append(wlen) + chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = logits_fn(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else context_size + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + base_model.train() + return _loss_bpb(loss_sum, token_count, byte_count) + + +def timed_eval(label: str, fn, *args, **kwargs) -> tuple[float, float]: + 
torch.cuda.synchronize() + t0 = time.perf_counter() + val_loss, val_bpb = fn(*args, **kwargs) + torch.cuda.synchronize() + elapsed_ms = 1000.0 * (time.perf_counter() - t0) + log(f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms") + return val_loss, val_bpb + + + def run_evals( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + eval_model: torch.nn.Module + ): + compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True) + timed_eval("final_codebook_roundtrip", eval_val, h, device, val_data, compiled_model) + if h.sliding_window_enabled: + timed_eval("final_codebook_sliding_window", eval_val_sliding, h, device, val_data, eval_model) + + # ---------------------------------------- + # Training + # ---------------------------------------- + + def train_model(h: Hyperparameters, device: torch.device) -> tuple[GPT, nn.Module]: + # Set up model + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + if h.distributed: + model = DDP(compiled_model, device_ids=[h.local_rank], broadcast_buffers=False) + else: + model = compiled_model + log(f"model_params:{sum(p.numel() for p in base_model.parameters())}") + + # Set up optimizer and load train data + optimizers = Optimizers(h, base_model) + train_loader = DistributedTokenLoader(h.train_files, h.rank, h.world_size, device) + soft_snap = FixedCodebookSoftSnap(h, base_model) if h.codebook_soft_snap_enabled else None + codebook_penalty = FixedCodebookPenalty(h, base_model) if h.codebook_penalty_enabled else None + + # Helper functions for training + max_wallclock_ms = 1000.0 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + if max_wallclock_ms is not None and h.codebook_reserve_seconds > 0: + max_wallclock_ms = max(max_wallclock_ms - h.codebook_reserve_seconds * 1000.0, 0.0) + log(f"codebook:reserving {h.codebook_reserve_seconds:.0f}s, 
effective={max_wallclock_ms:.0f}ms") + + def training_frac(step: int, elapsed_ms: float) -> float: + """Fraction of training completed (0 to 1), using step or wallclock.""" + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-9) + + def lr_mul(frac: float) -> float: + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + def step_fn(step, lr_scale, frac): + optimizers.zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + if h.distributed: + model.require_backward_grad_sync = micro_step == h.grad_accum_steps - 1 + x, y = train_loader.next_batch(h.train_batch_tokens, h.train_seq_len, h.grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + + momentum_frac = min(step / h.muon_momentum_warmup_steps, 1.0) if h.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - momentum_frac) * h.muon_momentum_warmup_start + momentum_frac * h.muon_momentum + for group in optimizers.optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * lr_scale + + codebook_penalty_loss = None + if codebook_penalty is not None: + codebook_penalty_loss = codebook_penalty.penalty(step, frac) + if codebook_penalty_loss is not None: + codebook_penalty_loss.backward() + + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm) + + optimizers.step() + if soft_snap is not None: + soft_snap.apply(step + 1) + return train_loss, codebook_penalty_loss + + # Model warmup + if h.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in 
base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0, 0.0) + if warmup_step <= 5 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == h.warmup_steps: + log(f"warmup_step: {warmup_step + 1}/{h.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + optimizers.zero_grad_all() + if h.distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader( + h.train_files, h.rank, h.world_size, device) + + # Training loop + use_ema = h.ema_decay < 1.0 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} if use_ema else None + ema_decay = h.ema_decay + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == h.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (h.val_loss_every > 0 and step % h.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(h, device, val_data, model) + log(f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}") + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < h.iterations: + log( + f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms " + f"step: {step}/{h.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + if soft_snap is not None and soft_snap.should_activate(frac): + soft_snap.activate(step, frac) + if codebook_penalty is not None and 
codebook_penalty.should_activate(frac): + codebook_penalty.activate(step, frac) + scale = lr_mul(frac) + train_loss, codebook_penalty_loss = step_fn(step, scale, frac) + + if use_ema: + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + should_log_train = ( + h.train_log_every > 0 + and (step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1000.0) + train_loss_value = train_loss.item() + total_loss_value = train_loss_value + codebook_penalty_suffix = "" + if codebook_penalty_loss is not None: + codebook_penalty_loss_value = codebook_penalty_loss.item() + total_loss_value += codebook_penalty_loss_value + codebook_penalty_suffix = f" codebook_penalty_loss: {codebook_penalty_loss_value:.5f}" + log( + f"{step}/{h.iterations} train_loss: {train_loss_value:.4f} total_loss: {total_loss_value:.4f} " + f"train_time: {approx_training_time_ms / 60000:.1f}m tok/s: {tok_per_sec:.0f}" + f"{codebook_penalty_suffix}" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if h.distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + if use_ema: + log("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + 
base_model.load_state_dict(avg_state, strict=True) + + return base_model, compiled_model + + +def train_and_eval(h: Hyperparameters, device: torch.device) -> None: + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + + val_data = ValidationData(h, device) + log(f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}") + log(f"val_tokens: {val_data.val_tokens.numel() - 1}") + + base_model, compiled_model = train_model(h, device, val_data) + timed_eval("pre-codebook post-ema" if h.ema_decay < 1.0 else "pre-codebook", eval_val, h, device, val_data, compiled_model) + + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + eval_model = deserialize(h, device) + + run_evals(h, device, val_data, eval_model) + + +def main(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + 
os.makedirs("logs", exist_ok=True) + log(100 * "=", console=False) + log("Hyperparameters:", console=True) + for k, v in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}", console=True) + log(Path(__file__).read_text(encoding="utf-8"), console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running PyTorch {torch.__version__}", console=False) + log( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log("=" * 100, console=False) + + train_and_eval(h, device) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Tue Apr 7 06:04:39 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 43C P0 125W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 37C P0 122W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 35C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 40C P0 121W / 700W | 1521MiB / 81559MiB | 2% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 43C P0 123W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 42C P0 123W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 34C P0 119W / 700W | 1521MiB / 81559MiB | 3% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + 
++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +train_shards: 80 +val_tokens: 44847104 +model_params:40173675 +codebook:reserving 30s, effective=570000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +0/20000 val_loss: 8.3164 val_bpb: 3.5693 +1/20000 train_loss: 8.3183 total_loss: 8.3183 train_time: 0.0m tok/s: 7185058 +2/20000 train_loss: 12.3459 total_loss: 12.3459 train_time: 0.0m tok/s: 7071602 +3/20000 train_loss: 10.8912 total_loss: 10.8912 train_time: 0.0m tok/s: 7016360 +4/20000 train_loss: 9.1354 total_loss: 9.1354 train_time: 0.0m tok/s: 6984574 +5/20000 train_loss: 7.9185 total_loss: 7.9185 train_time: 0.0m tok/s: 6971996 +500/20000 train_loss: 3.0399 total_loss: 3.0399 train_time: 1.0m tok/s: 6784688 +1000/20000 train_loss: 2.8926 total_loss: 2.8926 train_time: 1.9m tok/s: 6777009 +1500/20000 train_loss: 2.8818 total_loss: 2.8818 train_time: 2.9m tok/s: 6777745 +2000/20000 train_loss: 2.7798 total_loss: 2.7798 train_time: 3.9m tok/s: 6779562 +2500/20000 train_loss: 2.8056 total_loss: 2.8056 train_time: 4.8m tok/s: 6780161 +3000/20000 train_loss: 2.6894 total_loss: 2.6894 train_time: 5.8m tok/s: 6782420 +3500/20000 train_loss: 2.6598 total_loss: 2.6598 train_time: 6.8m tok/s: 6783794 +4000/20000 train_loss: 2.6499 total_loss: 2.6499 train_time: 7.7m tok/s: 6784804 +4000/20000 val_loss: 2.6398 val_bpb: 1.1330 +codebook_penalty:enabled step:4427/20000 start_frac:0.900 frac:0.900 weight:4.00000 
mode:fast_direction_target refresh_every:64 target_tensors:78 target_weights:37486592 +4500/20000 train_loss: 2.6050 total_loss: 2.7963 train_time: 8.8m tok/s: 6724505 codebook_penalty_loss: 0.19133 +4826/20000 val_loss: 2.5748 val_bpb: 1.1051 +stopping_early: wallclock_cap train_time: 570080ms step: 4826/20000 +peak memory allocated: 30446 MiB reserved: 31150 MiB +ema:applying EMA weights +pre-codebook post-ema val_loss:2.57227404 val_bpb:1.10397736 eval_time:1977ms +Serialized model: 155502379 bytes +Code size: 121768 bytes +codebook:collecting calibration hessians... +codebook:collected 78 hessians in 9.7s +codebook:fit target_tensors:78 target_weights:37486592 coverage:93.3113% rel_mse:3.508869e-02 target_bpw:4.0000 +codebook:fit_time:17.6s +Serialized model codebook+brotli: 15759400 bytes (payload_before_torchsave:17625458 bytes) +Codebook coverage:93.3113% target_bpw:3.1705 effective_payload_bpw_all_weights:3.5099 +Codebook scale_codec:log scale_bits:8 entropy_weight:0.0000 +Codebook outliers frac:0.000000 max_count:2048 outlier_weights:159744 outlier_bytes:798876 +Fallback payloads int8_weights:2621440 int8_bytes:2637824 passthrough_weights:65643 passthrough_bytes:131286 +codebook:entropy_summary tensors:78 idx_compressed_bytes:9371855 scale_compressed_bytes:3600453 combined_compressed_bytes:12961053 +codebook:tensor name:blocks.0.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:108822 combined_zip:370750 idx_H:15.459bits scale_H:10.751bits scale_dH:11.633bits scale_store:8.0bits scale_rel_delta:0.31645 scale_corr:-0.00083 +codebook:tensor name:blocks.5.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:108359 combined_zip:370186 idx_H:15.376bits scale_H:10.863bits scale_dH:11.772bits scale_store:8.0bits scale_rel_delta:0.34223 scale_corr:-0.00699 +codebook:tensor name:blocks.4.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:107807 combined_zip:369644 idx_H:15.421bits scale_H:10.813bits 
scale_dH:11.708bits scale_store:8.0bits scale_rel_delta:0.32899 scale_corr:-0.00114 +codebook:tensor name:blocks.2.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:107583 combined_zip:369373 idx_H:15.434bits scale_H:10.780bits scale_dH:11.664bits scale_store:8.0bits scale_rel_delta:0.32210 scale_corr:0.00147 +codebook:tensor name:blocks.7.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:107207 combined_zip:369081 idx_H:15.382bits scale_H:10.869bits scale_dH:11.773bits scale_store:8.0bits scale_rel_delta:0.34210 scale_corr:-0.00111 +codebook:tensor name:blocks.1.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:107145 combined_zip:368994 idx_H:15.442bits scale_H:10.764bits scale_dH:11.648bits scale_store:8.0bits scale_rel_delta:0.31888 scale_corr:0.00212 +codebook:tensor name:blocks.3.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:107069 combined_zip:368867 idx_H:15.423bits scale_H:10.785bits scale_dH:11.682bits scale_store:8.0bits scale_rel_delta:0.32466 scale_corr:-0.00273 +codebook:tensor name:blocks.9.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:106636 combined_zip:368587 idx_H:15.457bits scale_H:10.812bits scale_dH:11.720bits scale_store:8.0bits scale_rel_delta:0.33073 scale_corr:-0.00217 +codebook:tensor name:blocks.10.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:106247 combined_zip:368194 idx_H:15.455bits scale_H:10.799bits scale_dH:11.699bits scale_store:8.0bits scale_rel_delta:0.32687 scale_corr:0.00201 +codebook:tensor name:blocks.8.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:105866 combined_zip:367753 idx_H:15.419bits scale_H:10.823bits scale_dH:11.730bits scale_store:8.0bits scale_rel_delta:0.33372 scale_corr:-0.00677 +codebook:tensor name:blocks.11.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:105576 combined_zip:367504 idx_H:15.462bits 
scale_H:10.765bits scale_dH:11.658bits scale_store:8.0bits scale_rel_delta:0.32028 scale_corr:-0.00184 +codebook:tensor name:blocks.6.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:104807 combined_zip:366524 idx_H:15.329bits scale_H:10.902bits scale_dH:11.816bits scale_store:8.0bits scale_rel_delta:0.35111 scale_corr:-0.00836 +codebook:focus name:blocks.0.mlp.proj.weight idx_unique:53062 idx_H:15.4589bits idx_byte_H:7.9824bits idx_raw:262144 idx_entropy_bound:253279.0 idx_zip:262154 +codebook:focus name:blocks.0.mlp.proj.weight scale_unique:2955 scale_H:10.7514bits scale_dH:11.6334bits scale_byte_H:6.6177bits scale_store:8.0bits scale_raw:131072 scale_entropy_bound:176150.2 scale_zip:108822 +codebook:focus name:blocks.0.mlp.proj.weight combined_zip:370750 scale_delta_mean_abs:0.016324 scale_rel_delta:0.316450 scale_adjacent_corr:-0.000832 +Total submission size codebook+brotli: 15881168 bytes +final_codebook_roundtrip val_loss:2.84121342 val_bpb:1.21940168 eval_time:6900ms +final_codebook_sliding_window val_loss:2.80085242 val_bpb:1.20207941 eval_time:76510ms diff --git a/records/track_non_record_16mb/2026-04-07_Codebooks/vq_codebook_13_1337.log b/records/track_non_record_16mb/2026-04-07_Codebooks/vq_codebook_13_1337.log new file mode 100644 index 0000000000..423bc815ca --- /dev/null +++ b/records/track_non_record_16mb/2026-04-07_Codebooks/vq_codebook_13_1337.log @@ -0,0 +1,2998 @@ +==================================================================================================== +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + codebook_assignment_entropy_weight: 0.0 + codebook_block_dim: 8 + codebook_calibration_batches: 64 + codebook_entropy_focus: + codebook_entropy_summary_topk: 12 + codebook_hessian_damp: 0.01 + codebook_lattice_scale: 1.03 + codebook_outlier_frac: 0.0 + codebook_outlier_max_count: 2048 + codebook_penalty_enabled: True + codebook_penalty_refresh_every: 64 + 
codebook_penalty_start_frac: 0.9 + codebook_penalty_weight: 4.0 + codebook_reserve_seconds: 30.0 + codebook_scale_bits: 8 + codebook_soft_snap_alpha: 0.02 + codebook_soft_snap_enabled: False + codebook_soft_snap_start_frac: 0.9 + codebook_soft_snap_update_every: 1 + codebook_use_hadamard: True + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp4096 + distributed: True + ema_decay: 0.99 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + eval_seq_len: 2048 + eval_stride: 64 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/9ce39490-3c2e-4a3c-9a2e-e2fb765ce3df.txt + logit_softcap: 30.0 + matrix_lr: 0.02 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_wd: 0.085 + num_heads: 8 + num_kv_heads: 4 + num_layers: 13 + qk_gain_init: 4.0 + quantized_model_path: final_model.codebook.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: 9ce39490-3c2e-4a3c-9a2e-e2fb765ce3df + scalar_lr: 0.02 + seed: 1337 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_4096_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp4096/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp4096/fineweb_val_*.bin + val_loss_every: 4000 + ve_dim: 128 + ve_enabled: True + ve_layers: 9,10 + vocab_size: 4096 + warmdown_frac: 0.667 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 13 +import copy +import glob +import io +import lzma +import math +import os +from pathlib import Path +import random +import 
subprocess +import sys +import time +import uuid + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +from torch import Tensor, nn + +try: + from flash_attn_interface import flash_attn_func as _flash_attn_3_func +except ImportError: + _flash_attn_3_func = None + +def _use_flash_attn() -> bool: + if _flash_attn_3_func is None: + return False + v = os.environ.get("USE_FLASH_ATTN", "1").strip().lower() + return v not in ("0", "false", "no", "off") + +# ---------------------------------------- +# Hyperparameters +# ---------------------------------------- + +class Hyperparameters(): + # Experiment settings + data_dir = os.environ.get('DATA_DIR', './data/') + seed = int(os.environ.get('SEED', 1337)) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + + # Training length + iterations = int(os.environ.get('ITERATIONS', 20000)) + warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', 0.667)) + warmup_steps = int(os.environ.get('WARMUP_STEPS', 20)) + train_batch_tokens = int(os.environ.get('TRAIN_BATCH_TOKENS', 2048 * 48 * 8)) + train_seq_len = int(os.environ.get('TRAIN_SEQ_LEN', 2048)) + eval_seq_len = int(os.environ.get('EVAL_SEQ_LEN', 2048)) + max_wallclock_seconds = float(os.environ.get('MAX_WALLCLOCK_SECONDS', 600.0)) + train_log_every = int(os.environ.get('TRAIN_LOG_EVERY', 500)) + + # Validation/Evals + val_batch_tokens = int(os.environ.get('VAL_BATCH_TOKENS', 2048 * 32 * 8)) + val_loss_every = int(os.environ.get('VAL_LOSS_EVERY', 4000)) + sliding_window_enabled = bool(int(os.environ.get('SLIDING_WINDOW_ENABLED', '1'))) + + # Model architecture + vocab_size = int(os.environ.get('VOCAB_SIZE', 4096)) + num_layers = int(os.environ.get('NUM_LAYERS', 11)) + xsa_last_n = int(os.environ.get('XSA_LAST_N', 11)) + num_kv_heads = int(os.environ.get('NUM_KV_HEADS', 4)) + model_dim = int(os.environ.get('MODEL_DIM', 512)) + embedding_dim 
= int(os.environ.get('EMBEDDING_DIM', 512)) + num_heads = int(os.environ.get('NUM_HEADS', 8)) + mlp_mult = float(os.environ.get('MLP_MULT', 4.0)) + skip_gates_enabled = bool(int(os.environ.get('SKIP_GATES_ENABLED', '1'))) + tie_embeddings = bool(int(os.environ.get('TIE_EMBEDDINGS', '1'))) + logit_softcap = float(os.environ.get('LOGIT_SOFTCAP', 30.0)) + rope_base = float(os.environ.get('ROPE_BASE', 10000.0)) + rope_dims = int(os.environ.get('ROPE_DIMS', 16)) + rope_train_seq_len = int(os.environ.get('ROPE_TRAIN_SEQ_LEN', 2048)) + ln_scale = bool(int(os.environ.get('LN_SCALE', '1'))) + ve_enabled = bool(int(os.environ.get('VE_ENABLED', '1'))) + ve_dim = int(os.environ.get('VE_DIM', 128)) + ve_layers = os.environ.get('VE_LAYERS', '9,10') + qk_gain_init = float(os.environ.get('QK_GAIN_INIT', 4.0)) + + # Optimizer + min_lr = float(os.environ.get('MIN_LR', 0.0)) + embed_lr = float(os.environ.get('EMBED_LR', 0.6)) + head_lr = float(os.environ.get('HEAD_LR', 0.008)) + tied_embed_lr = float(os.environ.get('TIED_EMBED_LR', 0.03)) + tied_embed_init_std = float(os.environ.get('TIED_EMBED_INIT_STD', 0.005)) + matrix_lr = float(os.environ.get('MATRIX_LR', 0.02)) + scalar_lr = float(os.environ.get('SCALAR_LR', 0.02)) + muon_momentum = float(os.environ.get('MUON_MOMENTUM', 0.99)) + muon_backend_steps = int(os.environ.get('MUON_BACKEND_STEPS', 5)) + muon_momentum_warmup_start = float(os.environ.get('MUON_MOMENTUM_WARMUP_START', 0.92)) + muon_momentum_warmup_steps = int(os.environ.get('MUON_MOMENTUM_WARMUP_STEPS', 1500)) + beta1 = float(os.environ.get('BETA1', 0.9)) + beta2 = float(os.environ.get('BETA2', 0.95)) + adam_eps = float(os.environ.get('ADAM_EPS', 1e-8)) + grad_clip_norm = float(os.environ.get('GRAD_CLIP_NORM', 0.3)) + eval_stride = int(os.environ.get('EVAL_STRIDE', 64)) + muon_beta2 = float(os.environ.get('MUON_BETA2', 0.95)) + adam_wd = float(os.environ.get('ADAM_WD', 0.02)) + muon_wd = float(os.environ.get('MUON_WD', 0.085)) + embed_wd = float(os.environ.get('EMBED_WD', 
0.085)) + ema_decay = float(os.environ.get('EMA_DECAY', 1.0)) + + # Compression + compressor = os.environ.get('COMPRESSOR', 'brotli') #(lzma or brotli) + codebook_block_dim = int(os.environ.get("CODEBOOK_BLOCK_DIM", 8)) + codebook_use_hadamard = bool(int(os.environ.get("CODEBOOK_USE_HADAMARD", "1"))) + codebook_calibration_batches = int(os.environ.get("CODEBOOK_CALIBRATION_BATCHES", 64)) + codebook_reserve_seconds = float(os.environ.get("CODEBOOK_RESERVE_SECONDS", 10.0)) + codebook_hessian_damp = float(os.environ.get("CODEBOOK_HESSIAN_DAMP", 0.01)) + codebook_lattice_scale = float(os.environ.get("CODEBOOK_LATTICE_SCALE", 1.03)) + codebook_scale_bits = int(os.environ.get("CODEBOOK_SCALE_BITS", 8)) + codebook_assignment_entropy_weight = float(os.environ.get("CODEBOOK_ASSIGNMENT_ENTROPY_WEIGHT", 0.01)) + codebook_soft_snap_enabled = bool(int(os.environ.get("CODEBOOK_SOFT_SNAP_ENABLED", "0"))) + codebook_soft_snap_start_frac = float(os.environ.get("CODEBOOK_SOFT_SNAP_START_FRAC", 0.9)) + codebook_soft_snap_alpha = float(os.environ.get("CODEBOOK_SOFT_SNAP_ALPHA", 0.02)) + codebook_soft_snap_update_every = int(os.environ.get("CODEBOOK_SOFT_SNAP_UPDATE_EVERY", 1)) + codebook_penalty_enabled = bool( + int(os.environ.get("CODEBOOK_PENALTY_ENABLED", os.environ.get("CODEBOOK_STALE_PENALTY_ENABLED", "0"))) + ) + codebook_penalty_start_frac = float( + os.environ.get("CODEBOOK_PENALTY_START_FRAC", os.environ.get("CODEBOOK_STALE_PENALTY_START_FRAC", 0.75)) + ) + codebook_penalty_weight = float( + os.environ.get("CODEBOOK_PENALTY_WEIGHT", os.environ.get("CODEBOOK_STALE_PENALTY_WEIGHT", 0.01)) + ) + codebook_penalty_refresh_every = int( + os.environ.get("CODEBOOK_PENALTY_REFRESH_EVERY", os.environ.get("CODEBOOK_STALE_PENALTY_REFRESH_EVERY", 16)) + ) + codebook_outlier_frac = float(os.environ.get("CODEBOOK_OUTLIER_FRAC", 0.0)) + codebook_outlier_max_count = int(os.environ.get("CODEBOOK_OUTLIER_MAX_COUNT", 0)) + codebook_entropy_summary_topk = 
int(os.environ.get("CODEBOOK_ENTROPY_SUMMARY_TOPK", 12)) + codebook_entropy_focus = os.environ.get("CODEBOOK_ENTROPY_FOCUS", "").strip() + + # Distributed setup + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + + # Data paths + datasets_dir = os.path.join(data_dir, 'datasets', f'fineweb10B_sp{vocab_size}') + train_files = os.path.join(datasets_dir, 'fineweb_train_*.bin') + val_files = os.path.join(datasets_dir, 'fineweb_val_*.bin') + tokenizer_path = os.path.join(data_dir, 'tokenizers', f'fineweb_{vocab_size}_bpe.model') + + # Experiment files + logfile = f"logs/{run_id}.txt" + model_path = "final_model.pt" + quantized_model_path = "final_model.codebook.ptz" + +# ---------------------------------------- +# Global Logging Function +# ---------------------------------------- + +_logger_hparams = None + + +def set_logging_hparams(h: Hyperparameters) -> None: + global _logger_hparams + _logger_hparams = h + + +def log(msg, console: bool = True) -> None: + if _logger_hparams is None: + print(msg) + return + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + +# ---------------------------------------- +# Data Loading +# ---------------------------------------- + +class ValidationData: + def __init__(self, h: Hyperparameters, device: torch.device): + if not h.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only supports SentencePiece .model files: {h.tokenizer_path}") + self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError( + f"VOCAB_SIZE={h.vocab_size} does not match tokenizer 
vocab_size={int(self.sp.vocab_size())}" + ) + + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + self.base_bytes_lut, self.has_leading_space_lut, self.is_boundary_token_lut = ( + build_sentencepiece_luts(self.sp, h.vocab_size, device)) + + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + # The BPB calculation assumes "▁" is its own token so that leading-space bytes + # are counted correctly. See https://github.com/openai/parameter-golf/issues/897 + assert sp.piece_to_id("\u2581") != sp.unk_id(), \ + "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting" + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" int: + key = str(file) + cached = _SHARD_NTOKENS_CACHE.get(key) + if cached is not None: + return cached + header = np.fromfile(file, dtype=" np.memmap: + key = str(file) + mm = _MMAP_CACHE.get(key) + if mm is not None: + return mm + n = _read_num_tokens(file) + mm = np.memmap(file, mode="r", dtype=" int: + if n <= 1: + return 1 + while True: + s = int(self._rng.integers(1, n)) + if math.gcd(s, n) == 1: + return s + + def _reset_cursor(self, si: int, seq_len: int) -> None: + nt = int(self._num_tokens[si]) + max_phase = min(seq_len - 1, max(0, nt - seq_len - 1)) + phase = int(self._rng.integers(max_phase + 1)) if max_phase > 0 else 0 + bc = (nt - 1 - phase) // seq_len + self._cursor_phase[si] = phase + self._cursor_block_count[si] = bc + self._cursor_next[si] = 0 + self._cursor_start[si] = int(self._rng.integers(bc)) if bc > 1 else 0 + self._cursor_stride[si] = self._pick_coprime_stride(bc) + self._cursor_init[si] = True + + def _ensure_cursor(self, si: int, seq_len: int) -> None: + if not self._cursor_init[si] or self._cursor_next[si] >= self._cursor_block_count[si]: + self._reset_cursor(si, seq_len) + + def _take_from_shard(self, si: int, seq_len: int, count: int, out: list[tuple[int, int]]) -> None: + rem = count + while rem > 0: + self._ensure_cursor(si, seq_len) + bc = int(self._cursor_block_count[si]) + ni = int(self._cursor_next[si]) + take = min(rem, bc - ni) + phase = int(self._cursor_phase[si]) + start = int(self._cursor_start[si]) + stride = int(self._cursor_stride[si]) + for j in range(take): + bi = (start + (ni + j) * stride) % bc + out.append((si, phase + bi * seq_len)) + self._cursor_next[si] = ni + take + rem -= 
take + + def _init_pipeline(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> None: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + num_seqs = local_tokens // seq_len + global_num_seqs = num_seqs * self.world_size + self._cfg = (local_tokens, seq_len, num_seqs, global_num_seqs) + bbc = (self._num_tokens - 1) // seq_len + eligible = bbc > 0 + self._eligible_shards = np.nonzero(eligible)[0].astype(np.int64) + self._base_block_counts = bbc[self._eligible_shards].astype(np.int64) + + def _sample_global_windows(self) -> list[tuple[int, int]]: + assert self._cfg is not None and self._eligible_shards is not None + _, seq_len, _, gns = self._cfg + ec = int(self._eligible_shards.size) + progress = min(self._batches_built / 1800.0, 1.0) + remaining = np.empty(ec, dtype=np.float64) + for i, si in enumerate(self._eligible_shards.tolist()): + if self._cursor_init[si]: + r = int(self._cursor_block_count[si]) - int(self._cursor_next[si]) + remaining[i] = float(max(r, 1)) + else: + remaining[i] = float(self._base_block_counts[i]) + alpha = 0.90 - 0.40 * progress + weights = np.power(remaining, alpha) + ws = float(weights.sum()) + if not np.isfinite(ws) or ws <= 0.0: + weights = np.ones(ec, dtype=np.float64) + ws = float(weights.sum()) + probs = weights / ws + low = min(max(8, self.world_size), ec, gns) + high = min(max(32, self.world_size * 8), ec, gns) + mix = max(1, min(int(round(low + progress * (high - low))), ec, gns)) + cp = self._rng.choice(ec, size=mix, replace=False, p=probs) + cs = self._eligible_shards[cp] + cpr = probs[cp].copy() + cpr /= cpr.sum() + counts = np.ones(mix, dtype=np.int64) + extra = gns - mix + if extra > 0: + counts += self._rng.multinomial(extra, cpr).astype(np.int64) + perm = self._rng.permutation(mix) + cs, counts = cs[perm], counts[perm] + buckets: list[list[tuple[int, int]]] = [] + for si, cnt in zip(cs.tolist(), counts.tolist()): + b: list[tuple[int, int]] = [] + self._take_from_shard(int(si), seq_len, 
int(cnt), b) + if b: + if len(b) > 1: + bp = self._rng.permutation(len(b)) + b = [b[int(k)] for k in bp.tolist()] + buckets.append(b) + windows: list[tuple[int, int]] = [] + active = [i for i, bk in enumerate(buckets) if bk] + while active: + order = self._rng.permutation(len(active)) + new_active: list[int] = [] + for oi in order.tolist(): + bi = active[oi] + if buckets[bi]: + windows.append(buckets[bi].pop()) + if buckets[bi]: + new_active.append(bi) + active = new_active + return windows + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + if self._cfg is None: + self._init_pipeline(global_tokens, seq_len, grad_accum_steps) + _, _, num_seqs, _ = self._cfg + gw = self._sample_global_windows() + local_w = gw[self.rank::self.world_size] + x = torch.empty((num_seqs, seq_len), dtype=torch.int64) + y = torch.empty((num_seqs, seq_len), dtype=torch.int64) + for slot, (si, pos) in enumerate(local_w): + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor(np.array(mm[pos:pos + seq_len + 1], dtype=np.int64)) + x[slot] = window[:-1] + y[slot] = window[1:] + self._batches_built += 1 + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ---------------------------------------- +# Model Architecture +# ---------------------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + 
self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, + rope_base: float, qk_gain_init: float, train_seq_len: int): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + 
if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len) + self.use_xsa = False + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if _use_flash_attn(): + y = _flash_attn_3_func(q, k, v, causal=True) + else: + y = F.scaled_dot_product_attention(q, k, v, is_causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + + +class 
ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square()) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, + rope_base: float, qk_gain_init: float, train_seq_len: int, + layer_idx: int = 0, ln_scale: bool = False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed) + 
x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + return x_out + + +class GPT(nn.Module): + def __init__(self, h: Hyperparameters): + super().__init__() + self._ve_target_dim = h.num_kv_heads * (h.model_dim // h.num_heads) + if h.logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.tok_emb = nn.Embedding(h.vocab_size, h.embedding_dim) + if h.embedding_dim != h.model_dim: + self.embed_proj = CastedLinear(h.embedding_dim, h.model_dim, bias=False) + self.head_proj = CastedLinear(h.model_dim, h.embedding_dim, bias=False) + else: + self.embed_proj = None + self.head_proj = None + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32)) + self.skip_gates = nn.Parameter(torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32)) if h.skip_gates_enabled else None + self.blocks = nn.ModuleList([ + Block(h.model_dim, h.num_heads, h.num_kv_heads, h.mlp_mult, h.rope_base, + h.qk_gain_init, h.train_seq_len, layer_idx=i, ln_scale=h.ln_scale) + for i in range(h.num_layers) + ]) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary(head_dim, base=h.rope_base, train_seq_len=h.train_seq_len, rope_dims=h.rope_dims) + self.ve_layer_indices = [int(x) for x in h.ve_layers.split(",") if x.strip()] if h.ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = 
ValueEmbedding(h.vocab_size, h.ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if h.tie_embeddings else CastedLinear(h.embedding_dim, h.vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.embed_proj is not None: + x = self.embed_proj(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in 
range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + scaled_skip = self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + if self.head_proj is not None: + x = self.head_proj(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + return F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1), reduction="mean") + +# ---------------------------------------- +# Optimization +# ---------------------------------------- + +@torch.compile +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() 
if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +class Optimizers(): + def __init__(self, h: Hyperparameters, base_model: GPT): + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in + CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in + CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + + token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": 
token_lr, "base_lr": token_lr}] + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + + self.optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.embed_wd, + fused=True, + ) + self.optimizer_muon = Muon( + matrix_params, + lr=h.matrix_lr, + momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, + weight_decay=h.muon_wd, + ) + for group in self.optimizer_muon.param_groups: + group["base_lr"] = h.matrix_lr + self.optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.adam_wd, + fused=True, + ) + self.optimizers: list[torch.optim.Optimizer] = [self.optimizer_tok, self.optimizer_muon, self.optimizer_scalar] + if base_model.lm_head is not None: + self.optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": h.head_lr, "base_lr": h.head_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + fused=True, + ) + self.optimizers.insert(1, self.optimizer_head) + else: + self.optimizer_head = None + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self) -> None: + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def step(self): + for opt in self.optimizers: + opt.step() + self.zero_grad_all() + +# ---------------------------------------- +# Quantization +# ---------------------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + 
"attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +CODEBOOK_TARGET_PATTERNS = ( + "attn.c_q.weight", + "attn.c_k.weight", + "attn.c_v.weight", + "attn.proj.weight", + "mlp.fc.weight", + "mlp.proj.weight", +) + + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def dequantize_int8_rows(q: Tensor, s: Tensor, *, device: torch.device, dtype: torch.dtype) -> Tensor: + qf = q.to(device=device, dtype=torch.float32) + sf = s.to(device=device, dtype=torch.float32) + if sf.ndim > 0: + return (qf * sf.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype) + return (qf * float(sf.item())).to(dtype=dtype) + + +def restore_fp32_params(model: nn.Module) -> None: + """After .bfloat16(), restore CastedLinear weights and control params to FP32.""" + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern 
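The per-row int8 fallback above (clip each row at a high quantile, one fp16 scale per row) round-trips cleanly; a minimal self-contained sketch of the same scheme, run on a random matrix rather than real weights:

```python
import torch

def quantize_rows(t: torch.Tensor, clip_q: float = 0.9999984):
    # Per-row symmetric int8: clip each row at a high quantile of |w|,
    # then round into [-127, 127] with one fp16 scale per row.
    clip_abs = torch.quantile(t.abs(), clip_q, dim=1)
    clipped = torch.maximum(torch.minimum(t, clip_abs[:, None]), -clip_abs[:, None])
    scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
    q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8)
    return q, scale.to(torch.float16)

torch.manual_seed(0)
w = torch.randn(16, 64)
q, s = quantize_rows(w)
recon = q.float() * s.float()[:, None]
print(q.dtype, s.dtype, (w - recon).abs().max().item() < 0.1)
```

The worst-case rounding error per element is about half the row scale, i.e. roughly `max|row| / 254`, which is why the reconstruction stays well inside 0.1 for unit-variance weights.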
in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +def _is_power_of_two(n: int) -> bool: + return n > 0 and (n & (n - 1)) == 0 + + +def _validate_codebook_hparams(h: Hyperparameters) -> None: + if not _is_power_of_two(h.codebook_block_dim): + raise ValueError(f"CODEBOOK_BLOCK_DIM must be a power of 2, got {h.codebook_block_dim}") + if h.codebook_block_dim != 8: + raise ValueError(f"The E8P codebook expects CODEBOOK_BLOCK_DIM=8, got {h.codebook_block_dim}") + if h.codebook_calibration_batches < 0: + raise ValueError( + f"CODEBOOK_CALIBRATION_BATCHES must be non-negative, got {h.codebook_calibration_batches}" + ) + if h.codebook_reserve_seconds < 0: + raise ValueError( + f"CODEBOOK_RESERVE_SECONDS must be non-negative, got {h.codebook_reserve_seconds}" + ) + if h.codebook_hessian_damp < 0: + raise ValueError( + f"CODEBOOK_HESSIAN_DAMP must be non-negative, got {h.codebook_hessian_damp}" + ) + if h.codebook_lattice_scale <= 0: + raise ValueError(f"CODEBOOK_LATTICE_SCALE must be positive, got {h.codebook_lattice_scale}") + if h.codebook_scale_bits <= 0 or h.codebook_scale_bits > 16: + raise ValueError(f"CODEBOOK_SCALE_BITS must be in [1, 16], got {h.codebook_scale_bits}") + if h.codebook_assignment_entropy_weight < 0: + raise ValueError( + f"CODEBOOK_ASSIGNMENT_ENTROPY_WEIGHT must be non-negative, got {h.codebook_assignment_entropy_weight}" + ) + if not 0.0 <= h.codebook_soft_snap_start_frac <= 1.0: + raise ValueError( + f"CODEBOOK_SOFT_SNAP_START_FRAC must be in [0, 1], got {h.codebook_soft_snap_start_frac}" + ) + if not 0.0 <= h.codebook_soft_snap_alpha <= 1.0: + raise ValueError(f"CODEBOOK_SOFT_SNAP_ALPHA must be in [0, 1], got {h.codebook_soft_snap_alpha}") + if h.codebook_soft_snap_update_every <= 0: + raise ValueError( + f"CODEBOOK_SOFT_SNAP_UPDATE_EVERY must be positive, got {h.codebook_soft_snap_update_every}" + ) + if not 0.0 <= h.codebook_penalty_start_frac <= 1.0: + raise ValueError( + 
f"CODEBOOK_PENALTY_START_FRAC must be in [0, 1], got {h.codebook_penalty_start_frac}" + ) + if h.codebook_penalty_weight < 0: + raise ValueError( + f"CODEBOOK_PENALTY_WEIGHT must be non-negative, got {h.codebook_penalty_weight}" + ) + if h.codebook_penalty_refresh_every <= 0: + raise ValueError( + f"CODEBOOK_PENALTY_REFRESH_EVERY must be positive, got {h.codebook_penalty_refresh_every}" + ) + if not 0.0 <= h.codebook_outlier_frac <= 1.0: + raise ValueError(f"CODEBOOK_OUTLIER_FRAC must be in [0, 1], got {h.codebook_outlier_frac}") + if h.codebook_outlier_max_count < 0: + raise ValueError(f"CODEBOOK_OUTLIER_MAX_COUNT must be non-negative, got {h.codebook_outlier_max_count}") + + +def _blockify_weight(t: Tensor, block_dim: int) -> tuple[Tensor, tuple[int, int]]: + if t.ndim != 2: + raise ValueError(f"Codebook quantization expects a 2D tensor, got shape {tuple(t.shape)}") + if t.shape[1] % block_dim != 0: + raise ValueError( + f"Tensor shape {tuple(t.shape)} is not compatible with CODEBOOK_BLOCK_DIM={block_dim}" + ) + return t.contiguous().view(-1, block_dim), (int(t.shape[0]), int(t.shape[1])) + + +def _unblockify_weight(blocks: Tensor, original_shape: tuple[int, int]) -> Tensor: + return blocks.contiguous().view(original_shape) + + +def _should_codebook_quantize(name: str, t: Tensor, h: Hyperparameters) -> bool: + return ( + t.is_floating_point() + and t.ndim == 2 + and t.shape[1] % h.codebook_block_dim == 0 + and name.startswith("blocks.") + and any(name.endswith(pattern) for pattern in CODEBOOK_TARGET_PATTERNS) + ) +def _hadamard_sign_vector(name: str, block_dim: int, *, device: torch.device, dtype: torch.dtype) -> Tensor: + rng = random.Random(f"hadamard::{name}::{block_dim}") + return torch.tensor( + [1.0 if rng.getrandbits(1) else -1.0 for _ in range(block_dim)], + device=device, + dtype=dtype, + ) + + +def hadamard_transform(x: Tensor, scale: float = 1.0) -> Tensor: + original_shape = x.shape + dim = x.shape[-1] + padded_dim = 1 << (max(dim, 1) - 
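The blocking scheme validated above is just a zero-copy reshape into 8-wide blocks; a small sketch of the round trip plus the bits-per-weight arithmetic behind the 2.0/3.0 bpw figures in the write-up (block size 8 and the index/scale bit widths come from this file, the matrix shape is arbitrary):

```python
import torch

def blockify(t: torch.Tensor, block_dim: int = 8):
    # Split a 2D weight into contiguous blocks of 8 columns; view() makes
    # this a zero-copy reshape, and view(shape) inverts it exactly.
    assert t.ndim == 2 and t.shape[1] % block_dim == 0
    return t.contiguous().view(-1, block_dim), tuple(t.shape)

w = torch.randn(128, 512)
blocks, shape = blockify(w)
assert blocks.shape == (128 * 512 // 8, 8)
assert torch.equal(blocks.view(shape), w)

# Storage arithmetic: a 16-bit E8P index per 8 weights is 2.0 bits/weight;
# an fp16 scale per block adds 2.0 more (4.0 total), and quantizing scales
# to 8 bits instead gives the 3.0 bpw figure from the write-up.
print(16 / 8, 16 / 8 + 16 / 8, 16 / 8 + 8 / 8)
```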
1).bit_length() + out = x.reshape(-1, dim) + if padded_dim != dim: + out = F.pad(out, (0, padded_dim - dim)) + h = 1 + while h < padded_dim: + out = out.view(-1, padded_dim // (2 * h), 2, h) + a = out[:, :, 0, :] + b = out[:, :, 1, :] + out = torch.stack((a + b, a - b), dim=2).reshape(-1, padded_dim) + h *= 2 + out = out[:, :dim] + return (out * scale).reshape(*original_shape) + + +def hadamard_rotate_blocks(blocks: Tensor, sign_vec: Tensor, *, enabled: bool = True) -> Tensor: + if not enabled: + return blocks + return hadamard_transform(blocks * sign_vec, scale=blocks.shape[-1] ** -0.5) + + +def hadamard_unrotate_blocks(blocks: Tensor, sign_vec: Tensor, *, enabled: bool = True) -> Tensor: + if not enabled: + return blocks + return hadamard_transform(blocks, scale=blocks.shape[-1] ** -0.5) * sign_vec + + +def get_norm12() -> Tensor: + return torch.tensor([ + [3, 1, 1, 1, 3, 3, 3, 3], + [1, 3, 1, 1, 3, 3, 3, 3], + [1, 1, 3, 1, 3, 3, 3, 3], + [1, 1, 1, 3, 3, 3, 3, 3], + [3, 3, 3, 1, 3, 3, 1, 1], + [3, 3, 3, 1, 3, 1, 3, 1], + [3, 3, 3, 1, 1, 3, 3, 1], + [3, 3, 3, 1, 3, 1, 1, 3], + [3, 3, 3, 1, 1, 3, 1, 3], + [3, 3, 3, 1, 1, 1, 3, 3], + [3, 3, 1, 3, 3, 3, 1, 1], + [3, 3, 1, 3, 3, 1, 3, 1], + [3, 3, 1, 3, 1, 3, 3, 1], + [3, 3, 1, 3, 3, 1, 1, 3], + [3, 3, 1, 3, 1, 3, 1, 3], + [3, 3, 1, 3, 1, 1, 3, 3], + [3, 1, 3, 3, 3, 3, 1, 1], + [3, 1, 3, 3, 3, 1, 3, 1], + [3, 1, 3, 3, 1, 3, 3, 1], + [3, 1, 3, 3, 3, 1, 1, 3], + [3, 1, 3, 3, 1, 3, 1, 3], + [1, 3, 3, 3, 1, 1, 3, 3], + [1, 3, 3, 3, 3, 3, 1, 1], + [1, 3, 3, 3, 3, 1, 3, 1], + [1, 3, 3, 3, 1, 3, 3, 1], + [1, 3, 3, 3, 3, 1, 1, 3], + [1, 3, 3, 3, 1, 3, 1, 3], + [1, 1, 3, 3, 1, 3, 3, 3], + [3, 3, 1, 1, 3, 3, 3, 1], + ], dtype=torch.float32) / 2 + + +def get_packed_abs_grid() -> Tensor: + intr = torch.arange(-4, 4) + d8 = torch.cartesian_prod(*[intr] * 8).float() + 0.5 + d8_even = d8.sum(dim=-1) % 2 == 0 + d8_small = d8.norm(dim=-1).square() <= 10 + d8_abs = torch.unique(d8[torch.where(d8_even & d8_small)[0]].abs(), dim=0) + 
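The fast Walsh-Hadamard butterfly above is orthogonal (with the `dim ** -0.5` scale) and self-inverse, which is why the unrotate path can reuse the same transform; a standalone sketch checking the involution at the dim=8 block size used here:

```python
import torch
import torch.nn.functional as F

def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Iterative fast Walsh-Hadamard butterfly over the last dim, padding to
    # the next power of two (mirrors the in-repo implementation).
    dim = x.shape[-1]
    padded_dim = 1 << (max(dim, 1) - 1).bit_length()
    out = x.reshape(-1, dim)
    if padded_dim != dim:
        out = F.pad(out, (0, padded_dim - dim))
    h = 1
    while h < padded_dim:
        out = out.view(-1, padded_dim // (2 * h), 2, h)
        a, b = out[:, :, 0, :], out[:, :, 1, :]
        out = torch.stack((a + b, a - b), dim=2).reshape(-1, padded_dim)
        h *= 2
    return (out[:, :dim] * scale).reshape(x.shape)

torch.manual_seed(0)
x = torch.randn(4, 8)
# With scale = dim ** -0.5 the scaled matrix H/sqrt(8) is symmetric and
# orthogonal, so applying the transform twice recovers the input.
y = hadamard_transform(hadamard_transform(x, scale=8 ** -0.5), scale=8 ** -0.5)
print(torch.allclose(x, y, atol=1e-5))
```

Note the involution only holds when `dim` is a power of two (no padding/truncation), which is always the case for the 8-dim blocks used here.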
codebook_abs = torch.cat([d8_abs, get_norm12()], dim=0) + codebook_abs = codebook_abs[:, [0, 2, 4, 6, 1, 3, 5, 7]] + codebook_abs[:, 7] *= (1 - 2 * (codebook_abs.sum(1) % 2)) + codebook_abs = (codebook_abs * 2 + 8).to(torch.int32) + acc = codebook_abs[:, 0] + for i in range(7): + acc = acc | (codebook_abs[:, i + 1] << ((i + 1) * 4)) + return acc + + +def get_abs_grid() -> Tensor: + intr = torch.arange(-4, 4) + d8 = torch.cartesian_prod(*[intr] * 8).float() + 0.5 + d8_even = d8.sum(dim=-1) % 2 == 0 + d8_small = d8.norm(dim=-1).square() <= 10 + d8_abs = torch.unique(d8[torch.where(d8_even & d8_small)[0]].abs(), dim=0) + return torch.cat([d8_abs, get_norm12()], dim=0) + + +def get_full_grid(packed_abs_grid: Tensor) -> tuple[Tensor, Tensor]: + codebook = torch.zeros(1 << 16, 8) + parity_idx: list[int] = [] + shuffle_map = [0, 4, 1, 5, 2, 6, 3, 7] + for code in range(1 << 16): + signs = code & 255 + abs_idx = code >> 8 + parity = 0 + for bit in range(8): + parity ^= (signs >> bit) & 1 + signs ^= parity + abs_code = packed_abs_grid[abs_idx].item() + for i, shuffled in enumerate(shuffle_map): + codebook[code, i] = (((abs_code >> (4 * shuffled)) & 15) - 8) * 0.5 + if (signs >> shuffled) & 1: + codebook[code, i] *= -1 + if parity: + codebook[code] -= 0.25 + parity_idx.append(code) + else: + codebook[code] += 0.25 + return codebook, torch.tensor(parity_idx, dtype=torch.long) + + +_E8P_PACKED_ABS = get_packed_abs_grid() +_E8P_GRID, _E8P_PARITY_IDX = get_full_grid(_E8P_PACKED_ABS) + + +class E8P12Codebook(nn.Module): + def __init__(self) -> None: + super().__init__() + self.codesz = 8 + self.register_buffer("grid", _E8P_GRID) + self.register_buffer("grid_norm", _E8P_GRID.norm(dim=-1).square()) + grid_part = _E8P_GRID[_E8P_PARITY_IDX] + 0.25 + grid_part = grid_part[ + torch.where( + ((grid_part[:, :7] < 0).sum(dim=-1) <= 1) + & (grid_part[:, :7].min(dim=-1).values >= -0.5) + )[0] + ] + abs_grid = get_abs_grid() + self.register_buffer("grid_part", grid_part) + 
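For reference, the 16-bit E8P codes assembled here pack an 8-bit absolute-pattern index in the high byte and eight per-coordinate sign bits in the low byte; the packing arithmetic in isolation (the values are arbitrary examples):

```python
# 16-bit E8P code layout: high byte = absolute-pattern index (0..255),
# low byte = one sign bit per coordinate (values here are arbitrary).
abs_idx = 137
signs = 0b10110010
code = (abs_idx << 8) | signs
print(code >> 8, code & 255, code < (1 << 16))
```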
self.register_buffer("grid_part_norm", grid_part.norm(dim=-1).square()) + self.register_buffer("grid_abs_odd", abs_grid.sum(dim=-1) % 2 == 1) + self.register_buffer("part_abs_map", self.round(grid_part.abs(), abs_grid, abs_grid.norm(dim=-1).square())[1]) + self.register_buffer("bit_map", 2 ** torch.arange(8)) + + def round(self, x: Tensor, grid: Tensor, grid_norm: Tensor) -> tuple[Tensor, Tensor]: + idx = (2 * x @ grid.t() - grid_norm).argmax(dim=-1) + return grid[idx], idx + + def round_topk(self, x: Tensor, grid: Tensor, grid_norm: Tensor, k: int = 2) -> tuple[Tensor, Tensor]: + scores = 2 * x @ grid.t() - grid_norm + top_idx = scores.topk(k, dim=-1).indices + return grid[top_idx], top_idx + + def fast_quantize_part(self, x: Tensor, parity: bool) -> tuple[Tensor, Tensor, Tensor]: + x_part = torch.abs(x) + x_odd = torch.where((x < 0).sum(dim=-1) % 2 != 0)[0] + x_part[x_odd, 7] = -x_part[x_odd, 7] + mask = 1 - 2 * (x < 0).to(torch.float32) + mask[x_odd, 7] = -mask[x_odd, 7] + rounded, rounded_idx = self.round(x_part, self.grid_part, self.grid_part_norm) + vals = rounded * mask + err = (x - vals).norm(dim=-1) + abs_idx = self.part_abs_map[rounded_idx] + sign_mask = ((rounded < 0) ^ (mask < 0))[:, [0, 2, 4, 6, 1, 3, 5, 7]] + sign_mask[:, 7] ^= self.grid_abs_odd[abs_idx] + sign_mask[:, 0] ^= parity + idx = (abs_idx << 8) + (sign_mask * self.bit_map).sum(dim=-1).int() + return vals, idx, err + + def fast_quantize_part_topk(self, x: Tensor, parity: bool, k: int = 2) -> tuple[Tensor, Tensor, Tensor]: + x_part = torch.abs(x) + x_odd = torch.where((x < 0).sum(dim=-1) % 2 != 0)[0] + x_part[x_odd, 7] = -x_part[x_odd, 7] + mask = 1 - 2 * (x < 0).to(torch.float32) + mask[x_odd, 7] = -mask[x_odd, 7] + rounded, rounded_idx = self.round_topk(x_part, self.grid_part, self.grid_part_norm, k=k) + vals = rounded * mask.unsqueeze(1) + err = (x.unsqueeze(1) - vals).square().sum(dim=-1) + abs_idx = self.part_abs_map[rounded_idx] + sign_mask = ((rounded < 0) ^ (mask.unsqueeze(1) < 0))[:, 
:, [0, 2, 4, 6, 1, 3, 5, 7]] + sign_mask[:, :, 7] ^= self.grid_abs_odd[abs_idx] + sign_mask[:, :, 0] ^= parity + idx = (abs_idx << 8) + (sign_mask * self.bit_map.view(1, 1, -1)).sum(dim=-1).int() + return vals, idx, err + + def quantize(self, x: Tensor) -> tuple[Tensor, Tensor]: + plus_vals, plus_idx, plus_err = self.fast_quantize_part(x + 0.25, True) + minus_vals, minus_idx, minus_err = self.fast_quantize_part(x - 0.25, False) + use_plus = plus_err < minus_err + vals = torch.where(use_plus.unsqueeze(-1), plus_vals - 0.25, minus_vals + 0.25) + idx = torch.where(use_plus, plus_idx, minus_idx) + return vals, idx + + def quantize_top2(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: + plus_vals, plus_idx, plus_err = self.fast_quantize_part_topk(x + 0.25, True, k=2) + minus_vals, minus_idx, minus_err = self.fast_quantize_part_topk(x - 0.25, False, k=2) + cand_vals = torch.cat((plus_vals - 0.25, minus_vals + 0.25), dim=1) + cand_idx = torch.cat((plus_idx, minus_idx), dim=1) + cand_err = torch.cat((plus_err, minus_err), dim=1) + order = cand_err.topk(2, dim=1, largest=False).indices + gather_vals = cand_vals.gather(1, order.unsqueeze(-1).expand(-1, -1, self.codesz)) + gather_idx = cand_idx.gather(1, order) + gather_err = cand_err.gather(1, order) + return gather_vals[:, 0], gather_idx[:, 0], gather_vals[:, 1], gather_err[:, 1] - gather_err[:, 0] + + def decode(self, idxs: Tensor) -> Tensor: + idxs_long = idxs.long().reshape(-1) + vals = self.grid.index_select(0, idxs_long) + return vals.view(*idxs.shape, self.codesz) + + +_E8P_CODEBOOK_CACHE: dict[str, E8P12Codebook] = {} + + +def _e8p_cache_key(device: torch.device | None) -> str: + if device is None: + return "cpu" + return f"{device.type}:{device.index}" + + +def _get_e8p_codebook(device: torch.device | None) -> E8P12Codebook: + key = _e8p_cache_key(device) + cached = _E8P_CODEBOOK_CACHE.get(key) + if cached is not None: + return cached + codebook = E8P12Codebook() + if device is not None: + codebook = 
codebook.to(device=device, dtype=torch.float32) + else: + codebook = codebook.to(dtype=torch.float32) + codebook.eval() + _E8P_CODEBOOK_CACHE[key] = codebook + return codebook + + +@torch.no_grad() +def _quantize_e8p_blocks(blocks: Tensor, lattice_scale: float) -> tuple[Tensor, Tensor]: + codebook = _get_e8p_codebook(blocks.device) + scaled = blocks.to(dtype=torch.float32) * float(lattice_scale) + vals, idxs = codebook.quantize(scaled) + return vals / float(lattice_scale), idxs.long() + + +@torch.no_grad() +def _decode_e8p_blocks( + idxs: Tensor, + lattice_scale: float, + *, + device: torch.device, + dtype: torch.dtype, +) -> Tensor: + codebook = _get_e8p_codebook(device) + vals = codebook.decode(idxs.to(device=device)) + return (vals / float(lattice_scale)).to(dtype=dtype) + + +@torch.no_grad() +def _fast_quantize_weight_target(h: Hyperparameters, name: str, weight: Tensor) -> Tensor: + weight_f32 = weight.detach().float() + blocks, shape = _blockify_weight(weight_f32, h.codebook_block_dim) + sign_vec = _hadamard_sign_vector(name, h.codebook_block_dim, device=weight.device, dtype=torch.float32) + rotated_blocks = hadamard_rotate_blocks(blocks, sign_vec, enabled=h.codebook_use_hadamard) + block_norms = rotated_blocks.norm(dim=-1, keepdim=True).clamp_min(1e-8) + codebook = _get_e8p_codebook(weight.device) + target_radius = float(codebook.grid.norm(dim=-1).median().item()) + proxy_blocks = rotated_blocks / block_norms * target_radius + quantized_proxy, _ = _quantize_e8p_blocks(proxy_blocks, h.codebook_lattice_scale) + quantized_dirs = F.normalize(quantized_proxy.to(dtype=torch.float32), dim=-1) + recon_rotated = quantized_dirs * block_norms + recon_blocks = hadamard_unrotate_blocks( + recon_rotated, + sign_vec, + enabled=h.codebook_use_hadamard, + ) + return _unblockify_weight(recon_blocks, shape).to(dtype=weight.dtype) + + +def _log_prior_from_assignments( + assignments: Tensor, + num_entries: int, + *, + smoothing: float = 1.0, +) -> Tensor: + flat = 
assignments.reshape(-1).to(dtype=torch.int64) + counts = torch.bincount(flat, minlength=num_entries).to(dtype=torch.float32) + probs = (counts + float(smoothing)) / float(counts.sum().item() + smoothing * num_entries) + return torch.log(probs) + + +@torch.no_grad() +def _metric_optimal_codewords( + rotated_blocks: Tensor, + metrics: Tensor, + candidates: Tensor, + *, + candidate_batch_size: int = 4096, + assignment_log_prior: Tensor | None = None, + assignment_entropy_weight: float = 0.0, +) -> tuple[Tensor, Tensor]: + if rotated_blocks.ndim != 3: + raise ValueError(f"Expected rotated blocks with shape [rows, positions, dim], got {tuple(rotated_blocks.shape)}") + expected_metrics_shape = (rotated_blocks.shape[1], rotated_blocks.shape[2], rotated_blocks.shape[2]) + if metrics.ndim != 3 or metrics.shape != expected_metrics_shape: + raise ValueError(f"Expected metrics with shape {expected_metrics_shape}, got {tuple(metrics.shape)}") + if candidates.ndim != 2 or candidates.shape[1] != rotated_blocks.shape[2]: + raise ValueError(f"Expected candidates [entries, {rotated_blocks.shape[2]}], got {tuple(candidates.shape)}") + if assignment_log_prior is not None and assignment_log_prior.shape != (candidates.shape[0],): + raise ValueError( + f"Expected assignment_log_prior with shape ({candidates.shape[0]},), got {tuple(assignment_log_prior.shape)}" + ) + num_rows, num_positions, _ = rotated_blocks.shape + selected_idx = torch.empty((num_rows, num_positions), device=rotated_blocks.device, dtype=torch.long) + selected_codewords = torch.empty_like(rotated_blocks, dtype=torch.float32) + for pos in range(num_positions): + metric = metrics[pos] + x = rotated_blocks[:, pos, :].to(dtype=torch.float32) + x_metric = torch.matmul(x, metric) + best_improvement = torch.full((num_rows,), -torch.inf, device=x.device, dtype=torch.float32) + best_idx = torch.zeros((num_rows,), device=x.device, dtype=torch.long) + for start in range(0, candidates.shape[0], candidate_batch_size): + cand = 
candidates[start:start + candidate_batch_size] + cand_metric = torch.matmul(cand, metric) + numer = torch.matmul(x_metric, cand.t()) + denom = (cand_metric * cand).sum(dim=-1).clamp_min(1e-8) + # x^T M x is candidate-independent, so minimizing the optimal metric error is + # equivalent to maximizing the gain from the best non-negative scale. + improvement = numer.clamp_min(0.0).square() / denom.unsqueeze(0) + if assignment_log_prior is not None and assignment_entropy_weight > 0: + improvement = improvement + float(assignment_entropy_weight) * assignment_log_prior[ + start:start + cand.shape[0] + ].unsqueeze(0) + chunk_best_improvement, chunk_best_idx = improvement.max(dim=1) + update = chunk_best_improvement > best_improvement + if update.any(): + best_improvement = torch.where(update, chunk_best_improvement, best_improvement) + best_idx = torch.where(update, start + chunk_best_idx, best_idx) + selected_idx[:, pos] = best_idx + selected_codewords[:, pos, :] = candidates.index_select(0, best_idx) + return selected_codewords, selected_idx + + +def collect_hessians( + model: nn.Module, + train_loader: DistributedTokenLoader, + h: Hyperparameters, + device: torch.device, + target_names: set[str], + n_calibration_batches: int = 64, +) -> dict[str, Tensor]: + """Run calibration batches and collect H = X^T X for selected CastedLinear layers.""" + if n_calibration_batches <= 0 or not target_names: + return {} + hessians: dict[str, Tensor] = {} + hooks = [] + was_training = model.training + + def make_hook(name: str): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + return hook_fn + + for module_name, module in model.named_modules(): + if not isinstance(module, CastedLinear): + continue + weight_name = f"{module_name}.weight" + if weight_name in 
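The candidate scoring above relies on the identity min over s >= 0 of (x - s*c)^T M (x - s*c) = x^T M x - (x^T M c)_+^2 / (c^T M c), so maximizing the gain term selects the metric-optimal codeword. A toy single-block version checked against brute force (the SPD `metric` here is a stand-in for a Hessian block; all names and shapes are illustrative):

```python
import torch

def pick_codeword(x: torch.Tensor, candidates: torch.Tensor, metric: torch.Tensor):
    # Choose the codeword c maximizing (x^T M c)_+^2 / (c^T M c); with the
    # optimal non-negative scale this minimizes (x - s*c)^T M (x - s*c).
    num = (x @ metric @ candidates.t()).clamp_min(0.0)
    den = torch.einsum("kd,de,ke->k", candidates, metric, candidates).clamp_min(1e-8)
    idx = (num.square() / den).argmax()
    return int(idx), float(num[idx] / den[idx])

torch.manual_seed(0)
d = 8
A = torch.randn(d, d)
metric = A @ A.t() + 0.1 * torch.eye(d)  # SPD stand-in for a Hessian block
cands = torch.randn(64, d)
x = torch.randn(d)
idx, scale = pick_codeword(x, cands, metric)

# Brute force: the chosen codeword attains the smallest metric error.
errs = []
for k in range(64):
    s = (x @ metric @ cands[k]).clamp_min(0.0) / (cands[k] @ metric @ cands[k])
    r = x - s * cands[k]
    errs.append(float(r @ metric @ r))
print(int(torch.tensor(errs).argmin()) == idx)
```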
target_names: + hooks.append(module.register_forward_hook(make_hook(weight_name))) + + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches): + x, _ = train_loader.next_batch( + h.train_batch_tokens, + h.train_seq_len, + h.grad_accum_steps, + ) + model.forward_logits(x) + + for hook in hooks: + hook.remove() + + world = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1 + for name, hessian in hessians.items(): + if world > 1: + dist.all_reduce(hessian, op=dist.ReduceOp.SUM) + hessians[name] = hessian.cpu() / float(max(world * n_calibration_batches, 1)) + + if was_training: + model.train() + return hessians + + +def _identity_block_metrics(shape: tuple[int, int], block_dim: int, device: torch.device) -> Tensor: + num_positions = shape[1] // block_dim + eye = torch.eye(block_dim, device=device, dtype=torch.float32) + return eye.unsqueeze(0).repeat(num_positions, 1, 1) + + +def _block_metrics_from_hessian( + shape: tuple[int, int], + block_dim: int, + sign_vec: Tensor, + hessian: Tensor | None, + *, + use_hadamard: bool, + damp_factor: float, + device: torch.device, +) -> Tensor: + if hessian is None or hessian.ndim != 2 or hessian.shape != (shape[1], shape[1]): + return _identity_block_metrics(shape, block_dim, device) + num_positions = shape[1] // block_dim + hessian_work = hessian.to(device=device, dtype=torch.float32).clone() + diag = hessian_work.diag() + dead = diag <= 0 + if dead.any(): + hessian_work[dead, dead] = 1.0 + damp = float(damp_factor * hessian_work.diag().mean().clamp_min(1e-8).item()) + hessian_work.diagonal().add_(damp) + reshaped = hessian_work.view(num_positions, block_dim, num_positions, block_dim) + block_idx = torch.arange(num_positions, device=device) + block_metrics = reshaped[block_idx, :, block_idx, :] + if use_hadamard: + unrotate = hadamard_unrotate_blocks( + torch.eye(block_dim, device=device, dtype=torch.float32), + sign_vec, + enabled=True, + ) + block_metrics = torch.matmul( + 
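The calibration pass accumulates H = X^T X through forward hooks; a minimal self-contained version of that pattern for a single `nn.Linear` (the layer sizes are arbitrary, and none of the repo's model/loader machinery is needed for the sketch):

```python
import torch
import torch.nn as nn

layer = nn.Linear(16, 4, bias=False)  # arbitrary shapes for the sketch
H = torch.zeros(16, 16)

def hook(module, inp, out):
    # Flatten [batch, seq, features] to [tokens, features], accumulate X^T X.
    x = inp[0].detach().float().reshape(-1, inp[0].shape[-1])
    H.addmm_(x.T, x)

handle = layer.register_forward_hook(hook)
with torch.no_grad():
    xs = torch.randn(8, 3, 16)
    layer(xs)
handle.remove()

flat = xs.reshape(-1, 16)
print(torch.allclose(H, flat.T @ flat, atol=1e-3))
```

Accumulating into a preallocated buffer with `addmm_` avoids materializing per-batch outer products, which is what makes it cheap to hook many layers at once.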
unrotate.unsqueeze(0), + torch.matmul(block_metrics, unrotate.t().unsqueeze(0)), + ) + block_metrics = 0.5 * (block_metrics + block_metrics.transpose(-1, -2)) + block_metrics.diagonal(dim1=-2, dim2=-1).add_(1e-6) + return block_metrics + + +def _metric_optimal_scales(rotated_blocks: Tensor, codebook_blocks: Tensor, metrics: Tensor) -> Tensor: + numer = torch.einsum("npd,pde,npe->np", codebook_blocks, metrics, rotated_blocks) + denom = torch.einsum("npd,pde,npe->np", codebook_blocks, metrics, codebook_blocks).clamp_min(1e-8) + scales = (numer / denom).clamp_min(1e-8) + max_fp16 = float(torch.finfo(torch.float16).max) + return scales.clamp(max=max_fp16).unsqueeze(-1) + + +@torch.no_grad() +def _metric_optimal_e8p_codewords( + rotated_blocks: Tensor, + metrics: Tensor, + lattice_scale: float, + *, + candidate_batch_size: int = 4096, + assignment_log_prior: Tensor | None = None, + assignment_entropy_weight: float = 0.0, +) -> tuple[Tensor, Tensor]: + codebook = _get_e8p_codebook(rotated_blocks.device) + candidates = codebook.grid.to(dtype=torch.float32) / float(lattice_scale) + return _metric_optimal_codewords( + rotated_blocks, + metrics, + candidates, + candidate_batch_size=candidate_batch_size, + assignment_log_prior=assignment_log_prior, + assignment_entropy_weight=assignment_entropy_weight, + ) + + +class FixedCodebookSoftSnap: + def __init__(self, h: Hyperparameters, model: nn.Module): + _validate_codebook_hparams(h) + self.h = h + self.target_modules: list[tuple[str, CastedLinear]] = [] + self.target_weights = 0 + self.active = False + self.last_snap_step: int | None = None + for module_name, module in model.named_modules(): + if not isinstance(module, CastedLinear): + continue + weight_name = f"{module_name}.weight" + if _should_codebook_quantize(weight_name, module.weight, h): + self.target_modules.append((weight_name, module)) + self.target_weights += int(module.weight.numel()) + + def should_activate(self, frac: float) -> bool: + return 
self.h.codebook_soft_snap_enabled and not self.active and frac >= self.h.codebook_soft_snap_start_frac + + @torch.no_grad() + def apply(self, step: int, *, force: bool = False) -> bool: + if not self.active or not self.target_modules: + return False + if not force and self.last_snap_step is not None: + if (step - self.last_snap_step) < self.h.codebook_soft_snap_update_every: + return False + alpha = float(self.h.codebook_soft_snap_alpha) + if alpha <= 0.0: + return False + for name, module in self.target_modules: + q_weight = _fast_quantize_weight_target(self.h, name, module.weight) + module.weight.lerp_(q_weight, alpha) + self.last_snap_step = step + return True + + def activate(self, step: int, frac: float) -> None: + if self.active: + return + self.active = True + if self.h.is_main_process: + log( + f"codebook_soft_snap:enabled step:{step}/{self.h.iterations} " + f"start_frac:{self.h.codebook_soft_snap_start_frac:.3f} frac:{frac:.3f} " + f"mode:fast_direction " + f"alpha:{self.h.codebook_soft_snap_alpha:.4f} " + f"target_tensors:{len(self.target_modules)} target_weights:{self.target_weights} " + f"update_every:{self.h.codebook_soft_snap_update_every}" + ) + + +class FixedCodebookPenalty: + def __init__(self, h: Hyperparameters, model: nn.Module): + _validate_codebook_hparams(h) + self.h = h + self.target_modules: list[tuple[str, CastedLinear]] = [] + self.target_weights = 0 + self.active = False + self.last_refresh_step: int | None = None + self.cached_targets: list[dict[str, object]] = [] + for module_name, module in model.named_modules(): + if not isinstance(module, CastedLinear): + continue + weight_name = f"{module_name}.weight" + if _should_codebook_quantize(weight_name, module.weight, h): + self.target_modules.append((weight_name, module)) + self.target_weights += int(module.weight.numel()) + + def should_activate(self, frac: float) -> bool: + return self.h.codebook_penalty_enabled and not self.active and frac >= self.h.codebook_penalty_start_frac + + def 
activate(self, step: int, frac: float) -> None: + if self.active: + return + self.active = True + if self.h.is_main_process: + log( + f"codebook_penalty:enabled step:{step}/{self.h.iterations} " + f"start_frac:{self.h.codebook_penalty_start_frac:.3f} frac:{frac:.3f} " + f"weight:{self.h.codebook_penalty_weight:.5f} " + f"mode:fast_direction_target " + f"refresh_every:{self.h.codebook_penalty_refresh_every} " + f"target_tensors:{len(self.target_modules)} target_weights:{self.target_weights}" + ) + + def _current_weight(self, frac: float) -> float: + if frac <= self.h.codebook_penalty_start_frac: + return 0.0 + return float(self.h.codebook_penalty_weight) + + @torch.no_grad() + def _refresh_targets(self, step: int) -> None: + refreshed: list[dict[str, object]] = [] + for name, module in self.target_modules: + q_weight = _fast_quantize_weight_target(self.h, name, module.weight) + refreshed.append( + { + "name": name, + "module": module, + "target_weight": q_weight.detach(), + } + ) + self.cached_targets = refreshed + self.last_refresh_step = step + + def penalty(self, step: int, frac: float) -> Tensor | None: + if not self.active: + return None + weight = self._current_weight(frac) + if weight <= 0.0: + return None + if self.last_refresh_step is None or (step - self.last_refresh_step) >= self.h.codebook_penalty_refresh_every: + self._refresh_targets(step) + if not self.cached_targets: + return None + losses: list[Tensor] = [] + for item in self.cached_targets: + module = item["module"] + target_weight = item["target_weight"] + diff = module.weight.float() - target_weight.float() + mse = diff.square().mean() + denom = target_weight.float().square().mean().clamp_min(1e-8) + losses.append(mse / denom) + if not losses: + return None + return float(weight) * torch.stack(losses).mean() + + +class CodebookQuantizer: + def __init__(self, h: Hyperparameters, model: nn.Module): + _validate_codebook_hparams(h) + self.h = h + self.states: dict[str, dict[str, object]] = {} + 
self.target_modules: list[tuple[str, CastedLinear]] = [] + self.target_names: set[str] = set() + self.last_fit_summary: dict[str, float] = {} + for module_name, module in model.named_modules(): + if not isinstance(module, CastedLinear): + continue + weight_name = f"{module_name}.weight" + if _should_codebook_quantize(weight_name, module.weight, h): + self.target_modules.append((weight_name, module)) + self.target_names.add(weight_name) + + @torch.no_grad() + def _prepare_rotated_blocks( + self, + name: str, + module: CastedLinear, + ) -> tuple[Tensor, tuple[int, int], Tensor]: + weight = module.weight.detach().float() + blocks, shape = _blockify_weight(weight, self.h.codebook_block_dim) + sign_vec = _hadamard_sign_vector(name, self.h.codebook_block_dim, device=blocks.device, dtype=torch.float32) + rotated_blocks = hadamard_rotate_blocks(blocks, sign_vec, enabled=self.h.codebook_use_hadamard) + return rotated_blocks, shape, sign_vec + + def _fit_tensor(self, name: str, module: CastedLinear, hessian: Tensor | None) -> tuple[str, dict[str, float]]: + weight = module.weight.detach().float() + rotated_blocks, shape, sign_vec = self._prepare_rotated_blocks(name, module) + blocks, _ = _blockify_weight(weight, self.h.codebook_block_dim) + num_rows, num_cols = shape + num_positions = num_cols // self.h.codebook_block_dim + rotated_grid = rotated_blocks.view(num_rows, num_positions, self.h.codebook_block_dim) + metrics = _block_metrics_from_hessian( + shape, + self.h.codebook_block_dim, + sign_vec, + hessian, + use_hadamard=self.h.codebook_use_hadamard, + damp_factor=self.h.codebook_hessian_damp, + device=blocks.device, + ) + initial_dirs, initial_idx = _metric_optimal_e8p_codewords( + rotated_grid, + metrics, + self.h.codebook_lattice_scale, + ) + if self.h.codebook_assignment_entropy_weight > 0: + e8p_entries = int(_get_e8p_codebook(rotated_blocks.device).grid.shape[0]) + assignment_log_prior = _log_prior_from_assignments(initial_idx, e8p_entries) + fixed_dirs, fixed_idx = 
_metric_optimal_e8p_codewords( + rotated_grid, + metrics, + self.h.codebook_lattice_scale, + assignment_log_prior=assignment_log_prior, + assignment_entropy_weight=self.h.codebook_assignment_entropy_weight, + ) + else: + assignment_log_prior = None + fixed_dirs, fixed_idx = initial_dirs, initial_idx + scales = _metric_optimal_scales(rotated_grid, fixed_dirs, metrics) + scales_fp16 = scales.to(dtype=torch.float16) + recon_rotated = fixed_dirs * scales_fp16.to(dtype=torch.float32) + recon_blocks = hadamard_unrotate_blocks( + recon_rotated.view(-1, self.h.codebook_block_dim), + sign_vec, + enabled=self.h.codebook_use_hadamard, + ) + diff = blocks - recon_blocks + mse = diff.square().mean() + energy = blocks.square().mean().clamp_min(1e-12) + stats = { + "num_weights": float(weight.numel()), + "mse": float(mse.item()), + "rel_mse": float((mse / energy).item()), + "scale_min": float(scales_fp16.min().item()), + "scale_mean": float(scales_fp16.float().mean().item()), + "scale_max": float(scales_fp16.max().item()), + } + state: dict[str, object] = { + "shape": shape, + "block_dim": self.h.codebook_block_dim, + "fixed_idx": fixed_idx.reshape(-1).to(dtype=torch.uint16).cpu().contiguous(), + "scales": scales_fp16.reshape(-1).cpu().contiguous(), + "stats": stats, + } + if assignment_log_prior is not None: + probs = assignment_log_prior.exp() + stats["assignment_prior_entropy"] = float((-(probs * assignment_log_prior).sum()).item()) + self.states[name] = state + return name, stats + + @torch.no_grad() + def fit(self, model: nn.Module, hessians: dict[str, Tensor]) -> None: + if not self.target_modules: + if self.h.is_main_process: + log("codebook:no eligible tensors found") + return + total_fp_weights = sum( + int(t.numel()) + for _, t in model.state_dict().items() + if t.is_floating_point() + ) + fit_items = [self._fit_tensor(name, module, hessians.get(name)) for name, module in self.target_modules] + target_weights = sum(int(stats["num_weights"]) for _, stats in fit_items) + 
weighted_rel_mse = sum(float(stats["rel_mse"]) * int(stats["num_weights"]) for _, stats in fit_items) + self.last_fit_summary = { + "target_tensors": float(len(fit_items)), + "target_weights": float(target_weights), + "coverage": target_weights / max(total_fp_weights, 1), + "rel_mse": weighted_rel_mse / max(target_weights, 1), + "target_bpw": 4.0, + } + if self.h.is_main_process: + log( + f"codebook:fit target_tensors:{len(fit_items)} target_weights:{target_weights} " + f"coverage:{self.last_fit_summary['coverage']:.4%} rel_mse:{self.last_fit_summary['rel_mse']:.6e} " + f"target_bpw:{self.last_fit_summary['target_bpw']:.4f}" + ) + + @torch.no_grad() + def _reconstruct_tensor_from_state(self, name: str, state: dict[str, object], *, dtype: torch.dtype = torch.float32) -> Tensor: + shape = tuple(int(x) for x in state["shape"]) + block_dim = int(state["block_dim"]) + idx = state["fixed_idx"].to(dtype=torch.int64) + scales = state["scales"].to(dtype=torch.float32).view(-1, 1) + fixed_blocks = _decode_e8p_blocks( + idx, + float(self.h.codebook_lattice_scale), + device=torch.device("cpu"), + dtype=torch.float32, + ) + rotated_blocks = fixed_blocks * scales + sign_vec = _hadamard_sign_vector(name, block_dim, device=torch.device("cpu"), dtype=torch.float32) + blocks = hadamard_unrotate_blocks( + rotated_blocks, + sign_vec, + enabled=bool(self.h.codebook_use_hadamard), + ) + return _unblockify_weight(blocks, shape).to(dtype) + + @torch.no_grad() + def _select_outliers(self, original: Tensor, reconstructed: Tensor) -> tuple[Tensor | None, Tensor | None, Tensor | None]: + requested = 0 + if self.h.codebook_outlier_frac > 0: + requested = int(math.ceil(float(self.h.codebook_outlier_frac) * float(original.numel()))) + if self.h.codebook_outlier_max_count > 0: + requested = ( + min(requested, int(self.h.codebook_outlier_max_count)) + if requested > 0 + else int(self.h.codebook_outlier_max_count) + ) + requested = min(requested, int(original.numel())) + if requested <= 0: + return 
None, None, None + flat_orig = original.reshape(-1).to(dtype=torch.float32) + flat_recon = reconstructed.reshape(-1).to(dtype=torch.float32) + err = (flat_orig - flat_recon).square() + if requested >= flat_orig.numel(): + idx = torch.arange(flat_orig.numel(), dtype=torch.int32) + else: + idx = torch.topk(err, k=requested, largest=True, sorted=False).indices.to(dtype=torch.int64) + idx = idx.index_select(0, torch.argsort(idx)).to(dtype=torch.int32).cpu() + vals = flat_orig.index_select(0, idx.to(dtype=torch.int64)).cpu().contiguous() + max_abs = float(vals.abs().max().item()) if vals.numel() > 0 else 0.0 + if max_abs <= 1e-12: + scale = torch.ones((1,), dtype=torch.float16) + q = torch.zeros(vals.numel(), dtype=torch.int8) + else: + scale_value = max_abs / 127.0 + q = torch.round(vals / scale_value).clamp_(-127, 127).to(dtype=torch.int8).contiguous() + scale = torch.tensor([scale_value], dtype=torch.float16) + return idx.contiguous(), q.cpu().contiguous(), scale.cpu().contiguous() + + def build_export( + self, + state_dict: dict[str, Tensor], + ) -> tuple[dict[str, Tensor], dict[str, object], dict[str, object]]: + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + export_stats: dict[str, object] = {"tensors": {}} + total_payload_bytes = 0 + target_payload_bytes = 0 + outlier_payload_bytes = 0 + outlier_weights = 0 + int8_fallback_payload_bytes = 0 + int8_fallback_weights = 0 + passthrough_payload_bytes = 0 + passthrough_weights = 0 + target_weights = 0 + total_fp_weights = sum(int(t.numel()) for _, t in state_dict.items() if t.is_floating_point()) + + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + if name in self.states: + state = self.states[name] + idx = state["fixed_idx"] + scales = state["scales"] + scale_payload, scale_meta = _quantize_codebook_scales(scales, self.h.codebook_scale_bits) + result[name + ".idx"] = idx + result[name + ".scale"] = scale_payload + tensor_outlier_count = 0 + 
tensor_outlier_payload_bytes = 0 + if self.h.codebook_outlier_frac > 0 or self.h.codebook_outlier_max_count > 0: + reconstructed = self._reconstruct_tensor_from_state(name, state, dtype=torch.float32) + outlier_idx, outlier_q, outlier_scale = self._select_outliers(t.float(), reconstructed) + if ( + outlier_idx is not None + and outlier_q is not None + and outlier_scale is not None + and outlier_idx.numel() > 0 + ): + result[name + ".outlier_idx"] = outlier_idx + result[name + ".outlier_q"] = outlier_q + result[name + ".outlier_scale"] = outlier_scale + tensor_outlier_count = int(outlier_idx.numel()) + tensor_outlier_payload_bytes = ( + outlier_idx.numel() * outlier_idx.element_size() + + outlier_q.numel() * outlier_q.element_size() + + outlier_scale.numel() * outlier_scale.element_size() + ) + meta[name] = { + "type": "codebook_e8p", + "shape": list(state["shape"]), + "block_dim": int(state["block_dim"]), + "hadamard": bool(self.h.codebook_use_hadamard), + "lattice_scale": float(self.h.codebook_lattice_scale), + "outlier_count": tensor_outlier_count, + "outlier_format": "int8" if tensor_outlier_count > 0 else "none", + **scale_meta, + } + payload_bytes = ( + idx.numel() * idx.element_size() + + scale_payload.numel() * scale_payload.element_size() + + tensor_outlier_payload_bytes + ) + total_payload_bytes += payload_bytes + target_payload_bytes += payload_bytes + outlier_payload_bytes += tensor_outlier_payload_bytes + outlier_weights += tensor_outlier_count + target_weights += int(state["stats"]["num_weights"]) + export_stats["tensors"][name] = { + **state["stats"], + "payload_bytes": payload_bytes, + "outlier_payload_bytes": tensor_outlier_payload_bytes, + "outlier_count": tensor_outlier_count, + "scale_bits": int(scale_meta["scale_bits"]), + "scale_format": str(scale_meta["scale_format"]), + "bpw": (8.0 * payload_bytes) / max(float(state["stats"]["num_weights"]), 1.0), + } + export_stats.setdefault("diagnostics", []).append( + _codebook_payload_diagnostics( + 
name, + idx, + scales, + scale_payload, + tuple(int(x) for x in state["shape"]), + int(state["block_dim"]), + self.h.compressor, + ) + ) + continue + + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + payload_bytes = result[name].numel() * result[name].element_size() + total_payload_bytes += payload_bytes + passthrough_payload_bytes += payload_bytes + if t.is_floating_point(): + passthrough_weights += int(t.numel()) + continue + + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + payload_bytes = result[name].numel() * result[name].element_size() + total_payload_bytes += payload_bytes + passthrough_payload_bytes += payload_bytes + passthrough_weights += int(t.numel()) + continue + + q, s = quantize_float_tensor(t) + result[name + ".q"] = q.cpu().contiguous() + result[name + ".scale"] = s.cpu().contiguous() + meta[name] = {"type": "int8"} + payload_bytes = ( + result[name + ".q"].numel() * result[name + ".q"].element_size() + + result[name + ".scale"].numel() * result[name + ".scale"].element_size() + ) + total_payload_bytes += payload_bytes + int8_fallback_payload_bytes += payload_bytes + int8_fallback_weights += int(t.numel()) + + export_stats["summary"] = { + "target_tensors": len(self.states), + "target_weights": target_weights, + "coverage": target_weights / max(total_fp_weights, 1), + "target_payload_bytes": target_payload_bytes, + "target_bpw": (8.0 * target_payload_bytes) / max(target_weights, 1), + "outlier_weights": outlier_weights, + "outlier_payload_bytes": outlier_payload_bytes, + "int8_fallback_weights": int8_fallback_weights, + "int8_fallback_payload_bytes": int8_fallback_payload_bytes, + "passthrough_weights": passthrough_weights, + "passthrough_payload_bytes": passthrough_payload_bytes, + "payload_bytes_before_torchsave": total_payload_bytes, + 
"effective_payload_bpw_all_weights": (8.0 * total_payload_bytes) / max(total_fp_weights, 1), + "model_fp_weights": total_fp_weights, + } + return result, meta, export_stats + + +def dequantize_state_dict_codebook( + result: dict[str, Tensor], + meta: dict[str, object], + template_sd: dict[str, Tensor], +) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + if info in ("passthrough", "passthrough_ctrl"): + t = result[name] + if t.dtype == torch.float16 and orig.dtype in (torch.float32, torch.bfloat16): + t = t.to(orig.dtype) + out[name] = t + continue + if not isinstance(info, dict): + raise ValueError(f"Unsupported compression metadata for {name}: {info!r}") + if info.get("type") == "int8": + out[name] = dequantize_int8_rows( + result[name + ".q"], + result[name + ".scale"], + device=torch.device("cpu"), + dtype=orig.dtype, + ) + continue + if info.get("type") not in ("codebook_e8p_fp16_scale", "codebook_e8p"): + raise ValueError(f"Unsupported compression metadata for {name}: {info!r}") + shape = tuple(int(x) for x in info["shape"]) + block_dim = int(info["block_dim"]) + idx = result[name + ".idx"].to(dtype=torch.int64) + if info.get("type") == "codebook_e8p_fp16_scale": + scales = result[name + ".scale"].to(dtype=torch.float32).view(-1, 1) + else: + scales = _dequantize_codebook_scales(result[name + ".scale"], info) + fixed_blocks = _decode_e8p_blocks( + idx, + float(info["lattice_scale"]), + device=torch.device("cpu"), + dtype=torch.float32, + ) + rotated_blocks = fixed_blocks * scales + sign_vec = _hadamard_sign_vector(name, block_dim, device=torch.device("cpu"), dtype=torch.float32) + blocks = hadamard_unrotate_blocks( + rotated_blocks, + sign_vec, + enabled=bool(info.get("hadamard", True)), + ) + restored = _unblockify_weight(blocks, shape).to(orig.dtype) + if int(info.get("outlier_count", 0)) > 0: + flat = restored.reshape(-1) + outlier_idx = result[name + 
".outlier_idx"].to(dtype=torch.int64) + outlier_q = result[name + ".outlier_q"].to(dtype=torch.float32).reshape(-1) + outlier_scale = float(result[name + ".outlier_scale"].to(dtype=torch.float32).reshape(-1)[0].item()) + outlier_val = (outlier_q * outlier_scale).to(dtype=orig.dtype) + flat[outlier_idx] = outlier_val + restored = flat.view_as(restored) + out[name] = restored + return out + + +_BSHF_MAGIC = b"BSHF" + + +def _byte_shuffle(data: bytes, stride: int = 2) -> bytes: + """Transpose byte stream by stride position for better compression.""" + if stride <= 1 or len(data) < stride: + return data + src = np.frombuffer(data, dtype=np.uint8) + n = len(src) + out = np.empty(n, dtype=np.uint8) + dest_off = 0 + for pos in range(stride): + chunk = src[pos::stride] + out[dest_off:dest_off + len(chunk)] = chunk + dest_off += len(chunk) + return _BSHF_MAGIC + bytes([stride]) + out.tobytes() + + +def _byte_unshuffle(data: bytes) -> bytes: + """Inverse of _byte_shuffle. Auto-detects BSHF magic header.""" + if len(data) < 5 or data[:4] != _BSHF_MAGIC: + return data + stride = data[4] + if stride < 2: + return data[5:] + payload = np.frombuffer(data, dtype=np.uint8, offset=5) + n = len(payload) + out = np.empty(n, dtype=np.uint8) + src_off = 0 + for pos in range(stride): + chunk_len = n // stride + (1 if pos < n % stride else 0) + out[pos::stride][:chunk_len] = payload[src_off:src_off + chunk_len] + src_off += chunk_len + return out.tobytes() + + +def _compress(data: bytes, compressor: str, byte_shuffle: bool = True) -> bytes: + if byte_shuffle: + data = _byte_shuffle(data) + if compressor == "lzma": + return lzma.compress(data, preset=6) + elif compressor == "brotli": + import brotli + return brotli.compress(data, quality=11) + raise ValueError(f"Unknown compressor: {compressor!r}") + + +def _decompress(data: bytes, compressor: str, byte_shuffle: bool = True) -> bytes: + if compressor == "lzma": + raw = lzma.decompress(data) + elif compressor == "brotli": + import brotli + 
raw = brotli.decompress(data) + else: + raise ValueError(f"Unknown compressor: {compressor!r}") + if byte_shuffle: + raw = _byte_unshuffle(raw) + return raw + + +def _codebook_scale_dtype(bits: int) -> torch.dtype: + return torch.uint8 if bits <= 8 else torch.uint16 + + +@torch.no_grad() +def _quantize_codebook_scales(scales: Tensor, bits: int) -> tuple[Tensor, dict[str, object]]: + scales_fp32 = scales.to(dtype=torch.float32).clamp_min(1e-12) + if bits >= 16: + return scales.to(dtype=torch.float16).contiguous(), { + "scale_format": "fp16", + "scale_bits": 16, + } + + levels = (1 << bits) - 1 + log_scales = torch.log2(scales_fp32) + log_min = float(log_scales.min().item()) + log_max = float(log_scales.max().item()) + if log_max <= log_min: + scale_codes = torch.zeros_like(scales_fp32, dtype=_codebook_scale_dtype(bits)) + log_step = 0.0 + else: + log_step = (log_max - log_min) / levels + scale_codes = torch.round((log_scales - log_min) / log_step).clamp_(0, levels).to(_codebook_scale_dtype(bits)) + return scale_codes.contiguous().cpu(), { + "scale_format": "log", + "scale_bits": int(bits), + "scale_log_min": log_min, + "scale_log_step": float(log_step), + } + + +@torch.no_grad() +def _dequantize_codebook_scales(scale_payload: Tensor, info: dict[str, object]) -> Tensor: + scale_format = str(info.get("scale_format", "fp16")) + if scale_format == "fp16": + return scale_payload.to(dtype=torch.float32).view(-1, 1) + if scale_format != "log": + raise ValueError(f"Unsupported codebook scale format: {scale_format!r}") + log_min = float(info["scale_log_min"]) + log_step = float(info["scale_log_step"]) + if log_step == 0.0: + return torch.full( + (scale_payload.numel(), 1), + 2.0 ** log_min, + dtype=torch.float32, + ) + logs = log_min + scale_payload.to(dtype=torch.float32).reshape(-1, 1) * log_step + return torch.pow(2.0, logs) + + + + +def _entropy_from_counts(counts: np.ndarray) -> float: + total = int(counts.sum()) + if total <= 0: + return 0.0 + probs = counts[counts > 
0].astype(np.float64) / float(total) + return float(-(probs * np.log2(probs)).sum()) + + +def _entropy_from_u8_bytes(data: bytes) -> float: + if not data: + return 0.0 + arr = np.frombuffer(data, dtype=np.uint8) + counts = np.bincount(arr, minlength=256) + return _entropy_from_counts(counts) + + +def _adjacent_corrcoef(arr: np.ndarray) -> float: + if arr.size < 2: + return 1.0 + x0 = arr[:-1].astype(np.float64, copy=False) + x1 = arr[1:].astype(np.float64, copy=False) + x0 = x0 - x0.mean() + x1 = x1 - x1.mean() + denom = math.sqrt(float(np.dot(x0, x0) * np.dot(x1, x1))) + if denom <= 0.0: + return 1.0 + return float(np.dot(x0, x1) / denom) + + +def _codebook_payload_diagnostics( + name: str, + idx: Tensor, + scales: Tensor, + scale_payload: Tensor, + shape: tuple[int, int], + block_dim: int, + compressor: str, +) -> dict[str, float | int | str]: + idx_np = idx.contiguous().numpy().reshape(-1) + idx_codes = idx_np.astype(np.int64, copy=False) + scales_np = scales.contiguous().numpy().reshape(-1).astype(np.float16, copy=False) + scale_payload_np = scale_payload.contiguous().numpy().reshape(-1) + idx_bytes = idx_np.tobytes() + scale_bytes = scale_payload_np.tobytes() + combined_bytes = idx_bytes + scale_bytes + + idx_max = int(idx_codes.max()) if idx_codes.size else 0 + idx_counts = np.bincount(idx_codes, minlength=max(idx_max + 1, 1)) + idx_entropy_bits = _entropy_from_counts(idx_counts) + idx_unique = int(np.count_nonzero(idx_counts)) + + scale_words = scales_np.view(np.uint16) + _, scale_counts = np.unique(scale_words, return_counts=True) + scale_value_entropy_bits = _entropy_from_counts(scale_counts) + + scale_grid = scales_np.astype(np.float32, copy=False).reshape(shape[0], shape[1] // block_dim) + if scale_grid.shape[1] > 1: + scale_deltas = np.diff(scale_grid, axis=1).astype(np.float16, copy=False).reshape(-1) + scale_delta_words = scale_deltas.view(np.uint16) + _, delta_counts = np.unique(scale_delta_words, return_counts=True) + scale_delta_entropy_bits = 
_entropy_from_counts(delta_counts) + scale_delta_mean_abs = float(np.mean(np.abs(scale_deltas.astype(np.float32, copy=False)))) + scale_adjacent_corr = _adjacent_corrcoef(scale_grid.reshape(-1)) + scale_relative_delta = scale_delta_mean_abs / max(float(np.mean(np.abs(scale_grid))), 1e-12) + else: + scale_delta_entropy_bits = 0.0 + scale_delta_mean_abs = 0.0 + scale_adjacent_corr = 1.0 + scale_relative_delta = 0.0 + + idx_compressed_bytes = len(_compress(idx_bytes, compressor)) + scale_compressed_bytes = len(_compress(scale_bytes, compressor)) + combined_compressed_bytes = len(_compress(combined_bytes, compressor)) + idx_byte_entropy = _entropy_from_u8_bytes(_byte_shuffle(idx_bytes)) + scale_byte_entropy = _entropy_from_u8_bytes(_byte_shuffle(scale_bytes)) + + return { + "name": name, + "idx_unique": idx_unique, + "idx_entropy_bits": idx_entropy_bits, + "idx_byte_entropy_bits": idx_byte_entropy, + "idx_raw_bytes": len(idx_bytes), + "idx_compressed_bytes": idx_compressed_bytes, + "idx_entropy_bound_bytes": (idx_np.size * idx_entropy_bits) / 8.0, + "scale_unique": int(scale_counts.size), + "scale_value_entropy_bits": scale_value_entropy_bits, + "scale_delta_entropy_bits": scale_delta_entropy_bits, + "scale_byte_entropy_bits": scale_byte_entropy, + "scale_raw_bytes": len(scale_bytes), + "scale_compressed_bytes": scale_compressed_bytes, + "scale_entropy_bound_bytes": (scales_np.size * scale_value_entropy_bits) / 8.0, + "scale_delta_mean_abs": scale_delta_mean_abs, + "scale_relative_delta": scale_relative_delta, + "scale_adjacent_corr": scale_adjacent_corr, + "scale_storage_bits": float(8.0 * len(scale_bytes) / max(scale_payload_np.size, 1)), + "combined_raw_bytes": len(combined_bytes), + "combined_compressed_bytes": combined_compressed_bytes, + } + + +def _select_codebook_focus_tensor( + diagnostics: list[dict[str, float | int | str]], + focus: str, +) -> dict[str, float | int | str] | None: + if not diagnostics: + return None + if focus: + for item in diagnostics: + if 
item["name"] == focus: + return item + for item in diagnostics: + if focus in str(item["name"]): + return item + return max(diagnostics, key=lambda item: float(item["scale_compressed_bytes"])) + + +def _log_codebook_diagnostics(h: Hyperparameters, quant_stats: dict[str, object]) -> None: + diagnostics = list(quant_stats.get("diagnostics", [])) + if not diagnostics: + return + topk = max(int(h.codebook_entropy_summary_topk), 0) + total_idx_compressed = sum(int(item["idx_compressed_bytes"]) for item in diagnostics) + total_scale_compressed = sum(int(item["scale_compressed_bytes"]) for item in diagnostics) + total_combined_compressed = sum(int(item["combined_compressed_bytes"]) for item in diagnostics) + log( + f"codebook:entropy_summary tensors:{len(diagnostics)} " + f"idx_compressed_bytes:{total_idx_compressed} " + f"scale_compressed_bytes:{total_scale_compressed} " + f"combined_compressed_bytes:{total_combined_compressed}" + ) + if topk > 0: + ranked = sorted( + diagnostics, + key=lambda item: (float(item["scale_compressed_bytes"]), float(item["combined_compressed_bytes"])), + reverse=True, + )[:topk] + for item in ranked: + log( + "codebook:tensor " + f"name:{item['name']} " + f"idx_raw:{item['idx_raw_bytes']} idx_zip:{item['idx_compressed_bytes']} " + f"scale_raw:{item['scale_raw_bytes']} scale_zip:{item['scale_compressed_bytes']} " + f"combined_zip:{item['combined_compressed_bytes']} " + f"idx_H:{float(item['idx_entropy_bits']):.3f}bits " + f"scale_H:{float(item['scale_value_entropy_bits']):.3f}bits " + f"scale_dH:{float(item['scale_delta_entropy_bits']):.3f}bits " + f"scale_store:{float(item['scale_storage_bits']):.1f}bits " + f"scale_rel_delta:{float(item['scale_relative_delta']):.5f} " + f"scale_corr:{float(item['scale_adjacent_corr']):.5f}" + ) + focus_item = _select_codebook_focus_tensor(diagnostics, h.codebook_entropy_focus) + if focus_item is None: + return + log( + "codebook:focus " + f"name:{focus_item['name']} " + 
f"idx_unique:{focus_item['idx_unique']} " + f"idx_H:{float(focus_item['idx_entropy_bits']):.4f}bits " + f"idx_byte_H:{float(focus_item['idx_byte_entropy_bits']):.4f}bits " + f"idx_raw:{focus_item['idx_raw_bytes']} " + f"idx_entropy_bound:{float(focus_item['idx_entropy_bound_bytes']):.1f} " + f"idx_zip:{focus_item['idx_compressed_bytes']}" + ) + log( + "codebook:focus " + f"name:{focus_item['name']} " + f"scale_unique:{focus_item['scale_unique']} " + f"scale_H:{float(focus_item['scale_value_entropy_bits']):.4f}bits " + f"scale_dH:{float(focus_item['scale_delta_entropy_bits']):.4f}bits " + f"scale_byte_H:{float(focus_item['scale_byte_entropy_bits']):.4f}bits " + f"scale_store:{float(focus_item['scale_storage_bits']):.1f}bits " + f"scale_raw:{focus_item['scale_raw_bytes']} " + f"scale_entropy_bound:{float(focus_item['scale_entropy_bound_bytes']):.1f} " + f"scale_zip:{focus_item['scale_compressed_bytes']}" + ) + log( + "codebook:focus " + f"name:{focus_item['name']} " + f"combined_zip:{focus_item['combined_compressed_bytes']} " + f"scale_delta_mean_abs:{float(focus_item['scale_delta_mean_abs']):.6f} " + f"scale_rel_delta:{float(focus_item['scale_relative_delta']):.6f} " + f"scale_adjacent_corr:{float(focus_item['scale_adjacent_corr']):.6f}" + ) + + +def serialize(h: Hyperparameters, base_model: torch.nn.Module, code: str) -> int: + quantizer = CodebookQuantizer(h, base_model) + device = next(base_model.parameters()).device + code_bytes = len(code.encode("utf-8")) + bytes_total = code_bytes + if h.is_main_process: + torch.save(base_model.state_dict(), h.model_path) + model_bytes = os.path.getsize(h.model_path) + log(f"Serialized model: {model_bytes} bytes") + log(f"Code size: {code_bytes} bytes") + + hessians: dict[str, Tensor] = {} + if quantizer.target_names: + if h.is_main_process: + log("codebook:collecting calibration hessians...") + t0 = time.perf_counter() + calib_loader = DistributedTokenLoader(h.train_files, h.rank, h.world_size, device) + hessians = 
collect_hessians( + base_model, + calib_loader, + h, + device, + quantizer.target_names, + n_calibration_batches=h.codebook_calibration_batches, + ) + if h.is_main_process: + log(f"codebook:collected {len(hessians)} hessians in {time.perf_counter() - t0:.1f}s") + + if h.is_main_process: + t0 = time.perf_counter() + quantizer.fit(base_model, hessians) + log(f"codebook:fit_time:{time.perf_counter() - t0:.1f}s") + quant_result, quant_meta, quant_stats = quantizer.build_export(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta, "s": quant_stats}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = _compress(quant_raw, h.compressor) + quant_file_bytes = len(quant_blob) + bytes_total = quant_file_bytes + code_bytes + with open(h.quantized_model_path, "wb") as f: + f.write(quant_blob) + summary = quant_stats["summary"] + log( + f"Serialized model codebook+{h.compressor}: {quant_file_bytes} bytes " + f"(payload_before_torchsave:{summary['payload_bytes_before_torchsave']} bytes)" + ) + log( + f"Codebook coverage:{summary['coverage']:.4%} target_bpw:{summary['target_bpw']:.4f} " + f"effective_payload_bpw_all_weights:{summary['effective_payload_bpw_all_weights']:.4f}" + ) + log( + f"Codebook scale_codec:{'fp16' if h.codebook_scale_bits >= 16 else 'log'} " + f"scale_bits:{h.codebook_scale_bits} " + f"entropy_weight:{h.codebook_assignment_entropy_weight:.4f}" + ) + log( + f"Codebook outliers frac:{h.codebook_outlier_frac:.6f} " + f"max_count:{h.codebook_outlier_max_count} " + f"outlier_weights:{summary['outlier_weights']} " + f"outlier_bytes:{summary['outlier_payload_bytes']}" + ) + log( + f"Fallback payloads int8_weights:{summary['int8_fallback_weights']} " + f"int8_bytes:{summary['int8_fallback_payload_bytes']} " + f"passthrough_weights:{summary['passthrough_weights']} " + f"passthrough_bytes:{summary['passthrough_payload_bytes']}" + ) + _log_codebook_diagnostics(h, quant_stats) + log(f"Total submission size 
codebook+{h.compressor}: {bytes_total} bytes") + return bytes_total + + +def deserialize(h: Hyperparameters, device: torch.device) -> GPT: + eval_model = GPT(h).to(device).bfloat16() + restore_fp32_params(eval_model) + + template_sd = eval_model.state_dict() + + with open(h.quantized_model_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(_decompress(quant_blob_disk, h.compressor)), + map_location="cpu", + ) + deq_state = dequantize_state_dict_codebook(quant_state["w"], quant_state["m"], template_sd) + eval_model.load_state_dict(deq_state, strict=True) + + return eval_model + + +# ---------------------------------------- +# Evaluation +# ---------------------------------------- + +def _loss_bpb(loss_sum, token_count, byte_count) -> tuple[float, float]: + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + return val_loss, val_bpb + + +def eval_val( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + model: nn.Module +) -> tuple[float, float]: + seq_len = h.eval_seq_len + local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, " + f"GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * h.rank) // h.world_size + seq_end = (total_seqs * (h.rank + 1)) // h.world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + # Avoid leaving inference-mode tensors in module caches that training later reuses. 
+ with torch.no_grad(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_data.val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = val_data.base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (val_data.has_leading_space_lut[tgt_ids] & ~val_data.is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + model.train() + return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) + + +def eval_val_sliding( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + base_model: nn.Module, + batch_seqs: int = 32 +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + base_model.eval() + logits_fn = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + seq_len = h.eval_seq_len + context_size = seq_len - h.eval_stride + total_tokens = val_data.val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, h.eval_stride) + if ws + context_size < total_tokens] + + total_windows = len(window_starts) + my_s = (total_windows * h.rank) // h.world_size + my_e = (total_windows * 
(h.rank + 1)) // h.world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Avoid leaving inference-mode tensors in module caches that training later reuses. + with torch.no_grad(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + we = min(ws + seq_len, total_tokens) + wlen = we - ws + wlens.append(wlen) + chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = logits_fn(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else context_size + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + base_model.train() + return _loss_bpb(loss_sum, token_count, byte_count) + + +def timed_eval(label: str, fn, *args, **kwargs) -> tuple[float, float]: + 
torch.cuda.synchronize()
+ t0 = time.perf_counter()
+ val_loss, val_bpb = fn(*args, **kwargs)
+ torch.cuda.synchronize()
+ elapsed_ms = 1000.0 * (time.perf_counter() - t0)
+ log(f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms")
+ return val_loss, val_bpb
+ 
+ 
+ def run_evals(
+ h: Hyperparameters,
+ device: torch.device,
+ val_data: ValidationData,
+ eval_model: torch.nn.Module
+ ) -> None:
+ compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True)
+ timed_eval("final_codebook_roundtrip", eval_val, h, device, val_data, compiled_model)
+ if h.sliding_window_enabled:
+ timed_eval("final_codebook_sliding_window", eval_val_sliding, h, device, val_data, eval_model)
+ 
+ # ----------------------------------------
+ # Training
+ # ----------------------------------------
+ 
+ def train_model(h: Hyperparameters, device: torch.device, val_data: ValidationData) -> tuple[GPT, torch.nn.Module]:
+ # Set up model
+ base_model = GPT(h).to(device).bfloat16()
+ restore_fp32_params(base_model)
+ compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
+ if h.distributed:
+ model = DDP(compiled_model, device_ids=[h.local_rank], broadcast_buffers=False)
+ else:
+ model = compiled_model
+ log(f"model_params:{sum(p.numel() for p in base_model.parameters())}")
+ 
+ # Set up optimizer and load train data
+ optimizers = Optimizers(h, base_model)
+ train_loader = DistributedTokenLoader(h.train_files, h.rank, h.world_size, device)
+ soft_snap = FixedCodebookSoftSnap(h, base_model) if h.codebook_soft_snap_enabled else None
+ codebook_penalty = FixedCodebookPenalty(h, base_model) if h.codebook_penalty_enabled else None
+ 
+ # Helper functions for training
+ max_wallclock_ms = 1000.0 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None
+ if max_wallclock_ms is not None and h.codebook_reserve_seconds > 0:
+ max_wallclock_ms = max(max_wallclock_ms - h.codebook_reserve_seconds * 1000.0, 0.0)
+ log(f"codebook:reserving {h.codebook_reserve_seconds:.0f}s, 
effective={max_wallclock_ms:.0f}ms") + + def training_frac(step: int, elapsed_ms: float) -> float: + """Fraction of training completed (0 to 1), using step or wallclock.""" + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-9) + + def lr_mul(frac: float) -> float: + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + def step_fn(step, lr_scale, frac): + optimizers.zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + if h.distributed: + model.require_backward_grad_sync = micro_step == h.grad_accum_steps - 1 + x, y = train_loader.next_batch(h.train_batch_tokens, h.train_seq_len, h.grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + + momentum_frac = min(step / h.muon_momentum_warmup_steps, 1.0) if h.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - momentum_frac) * h.muon_momentum_warmup_start + momentum_frac * h.muon_momentum + for group in optimizers.optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * lr_scale + + codebook_penalty_loss = None + if codebook_penalty is not None: + codebook_penalty_loss = codebook_penalty.penalty(step, frac) + if codebook_penalty_loss is not None: + codebook_penalty_loss.backward() + + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm) + + optimizers.step() + if soft_snap is not None: + soft_snap.apply(step + 1) + return train_loss, codebook_penalty_loss + + # Model warmup + if h.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in 
base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0, 0.0) + if warmup_step <= 5 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == h.warmup_steps: + log(f"warmup_step: {warmup_step + 1}/{h.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + optimizers.zero_grad_all() + if h.distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader( + h.train_files, h.rank, h.world_size, device) + + # Training loop + use_ema = h.ema_decay < 1.0 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} if use_ema else None + ema_decay = h.ema_decay + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == h.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (h.val_loss_every > 0 and step % h.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(h, device, val_data, model) + log(f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}") + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < h.iterations: + log( + f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms " + f"step: {step}/{h.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + if soft_snap is not None and soft_snap.should_activate(frac): + soft_snap.activate(step, frac) + if codebook_penalty is not None and 
codebook_penalty.should_activate(frac): + codebook_penalty.activate(step, frac) + scale = lr_mul(frac) + train_loss, codebook_penalty_loss = step_fn(step, scale, frac) + + if use_ema: + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + should_log_train = ( + h.train_log_every > 0 + and (step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1000.0) + train_loss_value = train_loss.item() + total_loss_value = train_loss_value + codebook_penalty_suffix = "" + if codebook_penalty_loss is not None: + codebook_penalty_loss_value = codebook_penalty_loss.item() + total_loss_value += codebook_penalty_loss_value + codebook_penalty_suffix = f" codebook_penalty_loss: {codebook_penalty_loss_value:.5f}" + log( + f"{step}/{h.iterations} train_loss: {train_loss_value:.4f} total_loss: {total_loss_value:.4f} " + f"train_time: {approx_training_time_ms / 60000:.1f}m tok/s: {tok_per_sec:.0f}" + f"{codebook_penalty_suffix}" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if h.distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + if use_ema: + log("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + 
base_model.load_state_dict(avg_state, strict=True) + + return base_model, compiled_model + + +def train_and_eval(h: Hyperparameters, device: torch.device) -> None: + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + + val_data = ValidationData(h, device) + log(f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}") + log(f"val_tokens: {val_data.val_tokens.numel() - 1}") + + base_model, compiled_model = train_model(h, device, val_data) + timed_eval("pre-codebook post-ema" if h.ema_decay < 1.0 else "pre-codebook", eval_val, h, device, val_data, compiled_model) + + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + eval_model = deserialize(h, device) + + run_evals(h, device, val_data, eval_model) + + +def main(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + 
os.makedirs("logs", exist_ok=True) + log(100 * "=", console=False) + log("Hyperparameters:", console=True) + for k, v in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}", console=True) + log(Path(__file__).read_text(encoding="utf-8"), console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running PyTorch {torch.__version__}", console=False) + log( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log("=" * 100, console=False) + + train_and_eval(h, device) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.11.0+cu128 +Tue Apr 7 05:16:13 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 46C P0 128W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 39C P0 123W / 700W | 1505MiB / 81559MiB | 2% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 37C P0 122W / 700W | 1505MiB / 81559MiB | 2% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 43C P0 123W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 45C P0 124W / 700W | 1505MiB / 81559MiB | 1% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 38C P0 121W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 44C P0 125W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 37C P0 120W / 700W | 1505MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + 
++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +train_shards: 80 +val_tokens: 44847104 +model_params:40173675 +codebook:reserving 30s, effective=570000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +0/20000 val_loss: 8.3157 val_bpb: 3.5690 +1/20000 train_loss: 8.3170 total_loss: 8.3170 train_time: 0.0m tok/s: 7140866 +2/20000 train_loss: 12.3095 total_loss: 12.3095 train_time: 0.0m tok/s: 7107727 +3/20000 train_loss: 10.8235 total_loss: 10.8235 train_time: 0.0m tok/s: 7040995 +4/20000 train_loss: 9.0787 total_loss: 9.0787 train_time: 0.0m tok/s: 7008260 +5/20000 train_loss: 7.8296 total_loss: 7.8296 train_time: 0.0m tok/s: 6990702 +500/20000 train_loss: 3.0397 total_loss: 3.0397 train_time: 1.0m tok/s: 6827335 +1000/20000 train_loss: 2.8887 total_loss: 2.8887 train_time: 1.9m tok/s: 6827473 +1500/20000 train_loss: 2.8805 total_loss: 2.8805 train_time: 2.9m tok/s: 6830836 +2000/20000 train_loss: 2.7780 total_loss: 2.7780 train_time: 3.8m tok/s: 6828954 +2500/20000 train_loss: 2.8103 total_loss: 2.8103 train_time: 4.8m tok/s: 6830534 +3000/20000 train_loss: 2.6904 total_loss: 2.6904 train_time: 5.8m tok/s: 6831747 +3500/20000 train_loss: 2.6651 total_loss: 2.6651 train_time: 6.7m tok/s: 6833112 +4000/20000 train_loss: 2.6505 total_loss: 2.6505 train_time: 7.7m tok/s: 6834654 +4000/20000 val_loss: 2.6429 val_bpb: 1.1343 +codebook_penalty:enabled step:4460/20000 start_frac:0.900 frac:0.900 weight:4.00000 
mode:fast_direction_target refresh_every:64 target_tensors:78 target_weights:37486592 +4500/20000 train_loss: 2.6105 total_loss: 2.8348 train_time: 8.7m tok/s: 6784026 codebook_penalty_loss: 0.22425 +4866/20000 val_loss: 2.5752 val_bpb: 1.1053 +stopping_early: wallclock_cap train_time: 570017ms step: 4866/20000 +peak memory allocated: 30447 MiB reserved: 31148 MiB +ema:applying EMA weights +pre-codebook post-ema val_loss:2.57272514 val_bpb:1.10417097 eval_time:1976ms +Serialized model: 155502379 bytes +Code size: 121768 bytes +codebook:collecting calibration hessians... +codebook:collected 78 hessians in 9.7s +codebook:fit target_tensors:78 target_weights:37486592 coverage:93.3113% rel_mse:3.437927e-02 target_bpw:4.0000 +codebook:fit_time:18.0s +Serialized model codebook+brotli: 15738195 bytes (payload_before_torchsave:17625458 bytes) +Codebook coverage:93.3113% target_bpw:3.1705 effective_payload_bpw_all_weights:3.5099 +Codebook scale_codec:log scale_bits:8 entropy_weight:0.0000 +Codebook outliers frac:0.000000 max_count:2048 outlier_weights:159744 outlier_bytes:798876 +Fallback payloads int8_weights:2621440 int8_bytes:2637824 passthrough_weights:65643 passthrough_bytes:131286 +codebook:entropy_summary tensors:78 idx_compressed_bytes:9371509 scale_compressed_bytes:3582138 combined_compressed_bytes:12941632 +codebook:tensor name:blocks.8.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:109882 combined_zip:371753 idx_H:15.420bits scale_H:10.827bits scale_dH:11.727bits scale_store:8.0bits scale_rel_delta:0.33316 scale_corr:0.00179 +codebook:tensor name:blocks.9.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:109390 combined_zip:371362 idx_H:15.450bits scale_H:10.810bits scale_dH:11.714bits scale_store:8.0bits scale_rel_delta:0.32963 scale_corr:-0.00022 +codebook:tensor name:blocks.0.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:109312 combined_zip:371244 idx_H:15.459bits scale_H:10.759bits 
scale_dH:11.638bits scale_store:8.0bits scale_rel_delta:0.31745 scale_corr:0.00093 +codebook:tensor name:blocks.1.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:108854 combined_zip:370671 idx_H:15.442bits scale_H:10.759bits scale_dH:11.645bits scale_store:8.0bits scale_rel_delta:0.31824 scale_corr:-0.00121 +codebook:tensor name:blocks.6.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:108439 combined_zip:370244 idx_H:15.321bits scale_H:10.901bits scale_dH:11.807bits scale_store:8.0bits scale_rel_delta:0.34874 scale_corr:0.00681 +codebook:tensor name:blocks.4.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:108035 combined_zip:369821 idx_H:15.407bits scale_H:10.810bits scale_dH:11.705bits scale_store:8.0bits scale_rel_delta:0.32901 scale_corr:0.00011 +codebook:tensor name:blocks.2.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:107838 combined_zip:369612 idx_H:15.431bits scale_H:10.773bits scale_dH:11.660bits scale_store:8.0bits scale_rel_delta:0.32134 scale_corr:0.00023 +codebook:tensor name:blocks.3.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:107796 combined_zip:369564 idx_H:15.428bits scale_H:10.790bits scale_dH:11.680bits scale_store:8.0bits scale_rel_delta:0.32589 scale_corr:-0.00766 +codebook:tensor name:blocks.11.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:105496 combined_zip:367415 idx_H:15.456bits scale_H:10.758bits scale_dH:11.650bits scale_store:8.0bits scale_rel_delta:0.31935 scale_corr:-0.00415 +codebook:tensor name:blocks.7.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:105264 combined_zip:367035 idx_H:15.367bits scale_H:10.857bits scale_dH:11.764bits scale_store:8.0bits scale_rel_delta:0.33899 scale_corr:0.00135 +codebook:tensor name:blocks.10.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:104769 combined_zip:366723 idx_H:15.453bits 
scale_H:10.791bits scale_dH:11.692bits scale_store:8.0bits scale_rel_delta:0.32666 scale_corr:-0.00664 +codebook:tensor name:blocks.5.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:102130 combined_zip:363863 idx_H:15.364bits scale_H:10.876bits scale_dH:11.783bits scale_store:8.0bits scale_rel_delta:0.34438 scale_corr:-0.00499 +codebook:focus name:blocks.8.mlp.proj.weight idx_unique:52159 idx_H:15.4201bits idx_byte_H:7.9815bits idx_raw:262144 idx_entropy_bound:252642.6 idx_zip:262154 +codebook:focus name:blocks.8.mlp.proj.weight scale_unique:3112 scale_H:10.8270bits scale_dH:11.7274bits scale_byte_H:6.6825bits scale_store:8.0bits scale_raw:131072 scale_entropy_bound:177389.2 scale_zip:109882 +codebook:focus name:blocks.8.mlp.proj.weight combined_zip:371753 scale_delta_mean_abs:0.017138 scale_rel_delta:0.333161 scale_adjacent_corr:0.001785 +Total submission size codebook+brotli: 15859963 bytes +final_codebook_roundtrip val_loss:2.85256205 val_bpb:1.22427233 eval_time:7296ms +final_codebook_sliding_window val_loss:2.81219097 val_bpb:1.20694573 eval_time:76110ms diff --git a/records/track_non_record_16mb/2026-04-07_Codebooks/vq_codebook_13_42.log b/records/track_non_record_16mb/2026-04-07_Codebooks/vq_codebook_13_42.log new file mode 100644 index 0000000000..b344afe41b --- /dev/null +++ b/records/track_non_record_16mb/2026-04-07_Codebooks/vq_codebook_13_42.log @@ -0,0 +1,2998 @@ +==================================================================================================== +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + codebook_assignment_entropy_weight: 0.0 + codebook_block_dim: 8 + codebook_calibration_batches: 64 + codebook_entropy_focus: + codebook_entropy_summary_topk: 12 + codebook_hessian_damp: 0.01 + codebook_lattice_scale: 1.03 + codebook_outlier_frac: 0.0 + codebook_outlier_max_count: 2048 + codebook_penalty_enabled: True + codebook_penalty_refresh_every: 64 + codebook_penalty_start_frac: 0.9 + 
codebook_penalty_weight: 4.0 + codebook_reserve_seconds: 30.0 + codebook_scale_bits: 8 + codebook_soft_snap_alpha: 0.02 + codebook_soft_snap_enabled: False + codebook_soft_snap_start_frac: 0.9 + codebook_soft_snap_update_every: 1 + codebook_use_hadamard: True + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp4096 + distributed: True + ema_decay: 0.99 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + eval_seq_len: 2048 + eval_stride: 64 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/vq_codebook_13_2048_42.txt + logit_softcap: 30.0 + matrix_lr: 0.02 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_wd: 0.085 + num_heads: 8 + num_kv_heads: 4 + num_layers: 13 + qk_gain_init: 4.0 + quantized_model_path: final_model.codebook.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: vq_codebook_13_2048_42 + scalar_lr: 0.02 + seed: 42 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_4096_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp4096/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp4096/fineweb_val_*.bin + val_loss_every: 4000 + ve_dim: 128 + ve_enabled: True + ve_layers: 9,10 + vocab_size: 4096 + warmdown_frac: 0.667 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 13 +import copy +import glob +import io +import lzma +import math +import os +from pathlib import Path +import random +import subprocess +import sys +import time +import uuid + +import numpy 
as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +from torch import Tensor, nn + +try: + from flash_attn_interface import flash_attn_func as _flash_attn_3_func +except ImportError: + _flash_attn_3_func = None + +def _use_flash_attn() -> bool: + if _flash_attn_3_func is None: + return False + v = os.environ.get("USE_FLASH_ATTN", "1").strip().lower() + return v not in ("0", "false", "no", "off") + +# ---------------------------------------- +# Hyperparameters +# ---------------------------------------- + +class Hyperparameters(): + # Experiment settings + data_dir = os.environ.get('DATA_DIR', './data/') + seed = int(os.environ.get('SEED', 1337)) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + + # Training length + iterations = int(os.environ.get('ITERATIONS', 20000)) + warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', 0.667)) + warmup_steps = int(os.environ.get('WARMUP_STEPS', 20)) + train_batch_tokens = int(os.environ.get('TRAIN_BATCH_TOKENS', 2048 * 48 * 8)) + train_seq_len = int(os.environ.get('TRAIN_SEQ_LEN', 2048)) + eval_seq_len = int(os.environ.get('EVAL_SEQ_LEN', 2048)) + max_wallclock_seconds = float(os.environ.get('MAX_WALLCLOCK_SECONDS', 600.0)) + train_log_every = int(os.environ.get('TRAIN_LOG_EVERY', 500)) + + # Validation/Evals + val_batch_tokens = int(os.environ.get('VAL_BATCH_TOKENS', 2048 * 32 * 8)) + val_loss_every = int(os.environ.get('VAL_LOSS_EVERY', 4000)) + sliding_window_enabled = bool(int(os.environ.get('SLIDING_WINDOW_ENABLED', '1'))) + + # Model architecture + vocab_size = int(os.environ.get('VOCAB_SIZE', 4096)) + num_layers = int(os.environ.get('NUM_LAYERS', 11)) + xsa_last_n = int(os.environ.get('XSA_LAST_N', 11)) + num_kv_heads = int(os.environ.get('NUM_KV_HEADS', 4)) + model_dim = int(os.environ.get('MODEL_DIM', 512)) + embedding_dim = int(os.environ.get('EMBEDDING_DIM', 512)) + num_heads = 
int(os.environ.get('NUM_HEADS', 8)) + mlp_mult = float(os.environ.get('MLP_MULT', 4.0)) + skip_gates_enabled = bool(int(os.environ.get('SKIP_GATES_ENABLED', '1'))) + tie_embeddings = bool(int(os.environ.get('TIE_EMBEDDINGS', '1'))) + logit_softcap = float(os.environ.get('LOGIT_SOFTCAP', 30.0)) + rope_base = float(os.environ.get('ROPE_BASE', 10000.0)) + rope_dims = int(os.environ.get('ROPE_DIMS', 16)) + rope_train_seq_len = int(os.environ.get('ROPE_TRAIN_SEQ_LEN', 2048)) + ln_scale = bool(int(os.environ.get('LN_SCALE', '1'))) + ve_enabled = bool(int(os.environ.get('VE_ENABLED', '1'))) + ve_dim = int(os.environ.get('VE_DIM', 128)) + ve_layers = os.environ.get('VE_LAYERS', '9,10') + qk_gain_init = float(os.environ.get('QK_GAIN_INIT', 4.0)) + + # Optimizer + min_lr = float(os.environ.get('MIN_LR', 0.0)) + embed_lr = float(os.environ.get('EMBED_LR', 0.6)) + head_lr = float(os.environ.get('HEAD_LR', 0.008)) + tied_embed_lr = float(os.environ.get('TIED_EMBED_LR', 0.03)) + tied_embed_init_std = float(os.environ.get('TIED_EMBED_INIT_STD', 0.005)) + matrix_lr = float(os.environ.get('MATRIX_LR', 0.02)) + scalar_lr = float(os.environ.get('SCALAR_LR', 0.02)) + muon_momentum = float(os.environ.get('MUON_MOMENTUM', 0.99)) + muon_backend_steps = int(os.environ.get('MUON_BACKEND_STEPS', 5)) + muon_momentum_warmup_start = float(os.environ.get('MUON_MOMENTUM_WARMUP_START', 0.92)) + muon_momentum_warmup_steps = int(os.environ.get('MUON_MOMENTUM_WARMUP_STEPS', 1500)) + beta1 = float(os.environ.get('BETA1', 0.9)) + beta2 = float(os.environ.get('BETA2', 0.95)) + adam_eps = float(os.environ.get('ADAM_EPS', 1e-8)) + grad_clip_norm = float(os.environ.get('GRAD_CLIP_NORM', 0.3)) + eval_stride = int(os.environ.get('EVAL_STRIDE', 64)) + muon_beta2 = float(os.environ.get('MUON_BETA2', 0.95)) + adam_wd = float(os.environ.get('ADAM_WD', 0.02)) + muon_wd = float(os.environ.get('MUON_WD', 0.085)) + embed_wd = float(os.environ.get('EMBED_WD', 0.085)) + ema_decay = float(os.environ.get('EMA_DECAY', 
1.0)) + + # Compression + compressor = os.environ.get('COMPRESSOR', 'brotli') #(lzma or brotli) + codebook_block_dim = int(os.environ.get("CODEBOOK_BLOCK_DIM", 8)) + codebook_use_hadamard = bool(int(os.environ.get("CODEBOOK_USE_HADAMARD", "1"))) + codebook_calibration_batches = int(os.environ.get("CODEBOOK_CALIBRATION_BATCHES", 64)) + codebook_reserve_seconds = float(os.environ.get("CODEBOOK_RESERVE_SECONDS", 10.0)) + codebook_hessian_damp = float(os.environ.get("CODEBOOK_HESSIAN_DAMP", 0.01)) + codebook_lattice_scale = float(os.environ.get("CODEBOOK_LATTICE_SCALE", 1.03)) + codebook_scale_bits = int(os.environ.get("CODEBOOK_SCALE_BITS", 8)) + codebook_assignment_entropy_weight = float(os.environ.get("CODEBOOK_ASSIGNMENT_ENTROPY_WEIGHT", 0.01)) + codebook_soft_snap_enabled = bool(int(os.environ.get("CODEBOOK_SOFT_SNAP_ENABLED", "0"))) + codebook_soft_snap_start_frac = float(os.environ.get("CODEBOOK_SOFT_SNAP_START_FRAC", 0.9)) + codebook_soft_snap_alpha = float(os.environ.get("CODEBOOK_SOFT_SNAP_ALPHA", 0.02)) + codebook_soft_snap_update_every = int(os.environ.get("CODEBOOK_SOFT_SNAP_UPDATE_EVERY", 1)) + codebook_penalty_enabled = bool( + int(os.environ.get("CODEBOOK_PENALTY_ENABLED", os.environ.get("CODEBOOK_STALE_PENALTY_ENABLED", "0"))) + ) + codebook_penalty_start_frac = float( + os.environ.get("CODEBOOK_PENALTY_START_FRAC", os.environ.get("CODEBOOK_STALE_PENALTY_START_FRAC", 0.75)) + ) + codebook_penalty_weight = float( + os.environ.get("CODEBOOK_PENALTY_WEIGHT", os.environ.get("CODEBOOK_STALE_PENALTY_WEIGHT", 0.01)) + ) + codebook_penalty_refresh_every = int( + os.environ.get("CODEBOOK_PENALTY_REFRESH_EVERY", os.environ.get("CODEBOOK_STALE_PENALTY_REFRESH_EVERY", 16)) + ) + codebook_outlier_frac = float(os.environ.get("CODEBOOK_OUTLIER_FRAC", 0.0)) + codebook_outlier_max_count = int(os.environ.get("CODEBOOK_OUTLIER_MAX_COUNT", 0)) + codebook_entropy_summary_topk = int(os.environ.get("CODEBOOK_ENTROPY_SUMMARY_TOPK", 12)) + codebook_entropy_focus = 
os.environ.get("CODEBOOK_ENTROPY_FOCUS", "").strip() + + # Distributed setup + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + + # Data paths + datasets_dir = os.path.join(data_dir, 'datasets', f'fineweb10B_sp{vocab_size}') + train_files = os.path.join(datasets_dir, 'fineweb_train_*.bin') + val_files = os.path.join(datasets_dir, 'fineweb_val_*.bin') + tokenizer_path = os.path.join(data_dir, 'tokenizers', f'fineweb_{vocab_size}_bpe.model') + + # Experiment files + logfile = f"logs/{run_id}.txt" + model_path = "final_model.pt" + quantized_model_path = "final_model.codebook.ptz" + +# ---------------------------------------- +# Global Logging Function +# ---------------------------------------- + +_logger_hparams = None + + +def set_logging_hparams(h: Hyperparameters) -> None: + global _logger_hparams + _logger_hparams = h + + +def log(msg, console: bool = True) -> None: + if _logger_hparams is None: + print(msg) + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + +# ---------------------------------------- +# Data Loading +# ---------------------------------------- + +class ValidationData: + def __init__(self, h: Hyperparameters, device: torch.device): + if not h.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {h.tokenizer_path}") + self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError( + f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}" + ) + + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + 
self.base_bytes_lut, self.has_leading_space_lut, self.is_boundary_token_lut = ( + build_sentencepiece_luts(self.sp, h.vocab_size, device)) + + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + # The BPB calculation assumes "▁" is its own token so that leading-space bytes + # are counted correctly. See https://github.com/openai/parameter-golf/issues/897 + assert sp.piece_to_id("\u2581") != sp.unk_id(), \ + "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting" + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype("<i4").itemsize + num_tokens = _read_num_tokens(file) + arr = np.fromfile(file, dtype="<u2", offset=header_bytes, count=num_tokens) + return torch.from_numpy(arr.astype(np.int64)) + + +_SHARD_NTOKENS_CACHE: dict[str, int] = {} + + +def _read_num_tokens(file: Path) -> int: + key = str(file) + cached = _SHARD_NTOKENS_CACHE.get(key) + if cached is not None: + return cached + header = np.fromfile(file, dtype="<i4", count=256) + n = int(header[2]) + _SHARD_NTOKENS_CACHE[key] = n + return n + + +_MMAP_CACHE: dict[str, np.memmap] = {} + + +def _get_shard_memmap(file: Path) -> np.memmap: + key = str(file) + mm = _MMAP_CACHE.get(key) + if mm is not None: + return mm + n = _read_num_tokens(file) + mm = np.memmap(file, mode="r", dtype="<u2", offset=256 * np.dtype("<i4").itemsize, shape=(n,)) + _MMAP_CACHE[key] = mm + return mm + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + if not self.files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + self.rank = rank + self.world_size = world_size + self.device = device + self._rng = np.random.default_rng(0) + self._num_tokens = np.array([_read_num_tokens(f) for f in self.files], dtype=np.int64) + num_shards = len(self.files) + self._cursor_phase = np.zeros(num_shards, dtype=np.int64) + self._cursor_block_count = np.zeros(num_shards, dtype=np.int64) + self._cursor_next = np.zeros(num_shards, dtype=np.int64) + self._cursor_start = np.zeros(num_shards, dtype=np.int64) + self._cursor_stride = np.ones(num_shards, dtype=np.int64) + self._cursor_init = np.zeros(num_shards, dtype=bool) + self._cfg = None + self._eligible_shards = None + self._base_block_counts = None + self._batches_built = 0 + + def _pick_coprime_stride(self, n: int) -> int: + if n <= 1: + return 1 + while True: + s = int(self._rng.integers(1, n)) + if math.gcd(s, n) == 1: + return s + + def _reset_cursor(self, si: int, seq_len: int) -> None: + nt = int(self._num_tokens[si]) + max_phase = min(seq_len - 1, max(0, nt - seq_len - 1)) + phase = int(self._rng.integers(max_phase + 1)) if max_phase > 0 else 0 + bc = (nt - 1 - phase) // seq_len + self._cursor_phase[si] = phase + self._cursor_block_count[si] = bc + self._cursor_next[si] = 0 + self._cursor_start[si] = int(self._rng.integers(bc)) if bc > 1 else 0 + self._cursor_stride[si] = self._pick_coprime_stride(bc) + self._cursor_init[si] = True + + def _ensure_cursor(self, si: int, seq_len: int) -> None: + if not self._cursor_init[si] or self._cursor_next[si] >= self._cursor_block_count[si]: + self._reset_cursor(si, seq_len) + + def _take_from_shard(self, si: int, seq_len: int, count: int, out: list[tuple[int, int]]) -> None: + rem = count + while rem > 0: + self._ensure_cursor(si, seq_len) + bc = int(self._cursor_block_count[si]) + ni = int(self._cursor_next[si]) + take = min(rem, bc - ni) + phase = int(self._cursor_phase[si]) + start = int(self._cursor_start[si]) + stride = int(self._cursor_stride[si]) + for j in range(take): + bi = (start + (ni + j) * stride) % bc + out.append((si, phase + bi * seq_len)) + self._cursor_next[si] = ni + take + rem -=
take + + def _init_pipeline(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> None: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + num_seqs = local_tokens // seq_len + global_num_seqs = num_seqs * self.world_size + self._cfg = (local_tokens, seq_len, num_seqs, global_num_seqs) + bbc = (self._num_tokens - 1) // seq_len + eligible = bbc > 0 + self._eligible_shards = np.nonzero(eligible)[0].astype(np.int64) + self._base_block_counts = bbc[self._eligible_shards].astype(np.int64) + + def _sample_global_windows(self) -> list[tuple[int, int]]: + assert self._cfg is not None and self._eligible_shards is not None + _, seq_len, _, gns = self._cfg + ec = int(self._eligible_shards.size) + progress = min(self._batches_built / 1800.0, 1.0) + remaining = np.empty(ec, dtype=np.float64) + for i, si in enumerate(self._eligible_shards.tolist()): + if self._cursor_init[si]: + r = int(self._cursor_block_count[si]) - int(self._cursor_next[si]) + remaining[i] = float(max(r, 1)) + else: + remaining[i] = float(self._base_block_counts[i]) + alpha = 0.90 - 0.40 * progress + weights = np.power(remaining, alpha) + ws = float(weights.sum()) + if not np.isfinite(ws) or ws <= 0.0: + weights = np.ones(ec, dtype=np.float64) + ws = float(weights.sum()) + probs = weights / ws + low = min(max(8, self.world_size), ec, gns) + high = min(max(32, self.world_size * 8), ec, gns) + mix = max(1, min(int(round(low + progress * (high - low))), ec, gns)) + cp = self._rng.choice(ec, size=mix, replace=False, p=probs) + cs = self._eligible_shards[cp] + cpr = probs[cp].copy() + cpr /= cpr.sum() + counts = np.ones(mix, dtype=np.int64) + extra = gns - mix + if extra > 0: + counts += self._rng.multinomial(extra, cpr).astype(np.int64) + perm = self._rng.permutation(mix) + cs, counts = cs[perm], counts[perm] + buckets: list[list[tuple[int, int]]] = [] + for si, cnt in zip(cs.tolist(), counts.tolist()): + b: list[tuple[int, int]] = [] + self._take_from_shard(int(si), seq_len, 
int(cnt), b) + if b: + if len(b) > 1: + bp = self._rng.permutation(len(b)) + b = [b[int(k)] for k in bp.tolist()] + buckets.append(b) + windows: list[tuple[int, int]] = [] + active = [i for i, bk in enumerate(buckets) if bk] + while active: + order = self._rng.permutation(len(active)) + new_active: list[int] = [] + for oi in order.tolist(): + bi = active[oi] + if buckets[bi]: + windows.append(buckets[bi].pop()) + if buckets[bi]: + new_active.append(bi) + active = new_active + return windows + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + if self._cfg is None: + self._init_pipeline(global_tokens, seq_len, grad_accum_steps) + _, _, num_seqs, _ = self._cfg + gw = self._sample_global_windows() + local_w = gw[self.rank::self.world_size] + x = torch.empty((num_seqs, seq_len), dtype=torch.int64) + y = torch.empty((num_seqs, seq_len), dtype=torch.int64) + for slot, (si, pos) in enumerate(local_w): + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor(np.array(mm[pos:pos + seq_len + 1], dtype=np.int64)) + x[slot] = window[:-1] + y[slot] = window[1:] + self._batches_built += 1 + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ---------------------------------------- +# Model Architecture +# ---------------------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + 
self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, + rope_base: float, qk_gain_init: float, train_seq_len: int): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + 
if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len) + self.use_xsa = False + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if _use_flash_attn(): + y = _flash_attn_3_func(q, k, v, causal=True) + else: + # F.scaled_dot_product_attention attends over dim -2 and needs matching + # head counts, so transpose (B, T, H, D) -> (B, H, T, D) and expand the + # KV heads for GQA before falling back, then transpose back. + qt, kt, vt = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + if self.num_kv_heads != self.num_heads: + rep = self.num_heads // self.num_kv_heads + kt = kt.repeat_interleave(rep, dim=1) + vt = vt.repeat_interleave(rep, dim=1) + y = F.scaled_dot_product_attention(qt, kt, vt, is_causal=True).transpose(1, 2) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + + +class 
ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square()) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, + rope_base: float, qk_gain_init: float, train_seq_len: int, + layer_idx: int = 0, ln_scale: bool = False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed) + 
x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + return x_out + + +class GPT(nn.Module): + def __init__(self, h: Hyperparameters): + super().__init__() + self._ve_target_dim = h.num_kv_heads * (h.model_dim // h.num_heads) + if h.logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.tok_emb = nn.Embedding(h.vocab_size, h.embedding_dim) + if h.embedding_dim != h.model_dim: + self.embed_proj = CastedLinear(h.embedding_dim, h.model_dim, bias=False) + self.head_proj = CastedLinear(h.model_dim, h.embedding_dim, bias=False) + else: + self.embed_proj = None + self.head_proj = None + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32)) + self.skip_gates = nn.Parameter(torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32)) if h.skip_gates_enabled else None + self.blocks = nn.ModuleList([ + Block(h.model_dim, h.num_heads, h.num_kv_heads, h.mlp_mult, h.rope_base, + h.qk_gain_init, h.train_seq_len, layer_idx=i, ln_scale=h.ln_scale) + for i in range(h.num_layers) + ]) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary(head_dim, base=h.rope_base, train_seq_len=h.train_seq_len, rope_dims=h.rope_dims) + self.ve_layer_indices = [int(x) for x in h.ve_layers.split(",") if x.strip()] if h.ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = 
ValueEmbedding(h.vocab_size, h.ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if h.tie_embeddings else CastedLinear(h.embedding_dim, h.vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.embed_proj is not None: + x = self.embed_proj(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in 
range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + scaled_skip = self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + if self.head_proj is not None: + x = self.head_proj(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + return F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1), reduction="mean") + +# ---------------------------------------- +# Optimization +# ---------------------------------------- + +@torch.compile +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() 
if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +class Optimizers(): + def __init__(self, h: Hyperparameters, base_model: GPT): + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in + CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in + CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + + token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": 
token_lr, "base_lr": token_lr}] + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + + self.optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.embed_wd, + fused=True, + ) + self.optimizer_muon = Muon( + matrix_params, + lr=h.matrix_lr, + momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, + weight_decay=h.muon_wd, + ) + for group in self.optimizer_muon.param_groups: + group["base_lr"] = h.matrix_lr + self.optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.adam_wd, + fused=True, + ) + self.optimizers: list[torch.optim.Optimizer] = [self.optimizer_tok, self.optimizer_muon, self.optimizer_scalar] + if base_model.lm_head is not None: + self.optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": h.head_lr, "base_lr": h.head_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + fused=True, + ) + self.optimizers.insert(1, self.optimizer_head) + else: + self.optimizer_head = None + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self) -> None: + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def step(self): + for opt in self.optimizers: + opt.step() + self.zero_grad_all() + +# ---------------------------------------- +# Quantization +# ---------------------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + 
"attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +CODEBOOK_TARGET_PATTERNS = ( + "attn.c_q.weight", + "attn.c_k.weight", + "attn.c_v.weight", + "attn.proj.weight", + "mlp.fc.weight", + "mlp.proj.weight", +) + + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def dequantize_int8_rows(q: Tensor, s: Tensor, *, device: torch.device, dtype: torch.dtype) -> Tensor: + qf = q.to(device=device, dtype=torch.float32) + sf = s.to(device=device, dtype=torch.float32) + if sf.ndim > 0: + return (qf * sf.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype) + return (qf * float(sf.item())).to(dtype=dtype) + + +def restore_fp32_params(model: nn.Module) -> None: + """After .bfloat16(), restore CastedLinear weights and control params to FP32.""" + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern 
in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +def _is_power_of_two(n: int) -> bool: + return n > 0 and (n & (n - 1)) == 0 + + +def _validate_codebook_hparams(h: Hyperparameters) -> None: + if not _is_power_of_two(h.codebook_block_dim): + raise ValueError(f"CODEBOOK_BLOCK_DIM must be a power of 2, got {h.codebook_block_dim}") + if h.codebook_block_dim != 8: + raise ValueError(f"The E8P codebook expects CODEBOOK_BLOCK_DIM=8, got {h.codebook_block_dim}") + if h.codebook_calibration_batches < 0: + raise ValueError( + f"CODEBOOK_CALIBRATION_BATCHES must be non-negative, got {h.codebook_calibration_batches}" + ) + if h.codebook_reserve_seconds < 0: + raise ValueError( + f"CODEBOOK_RESERVE_SECONDS must be non-negative, got {h.codebook_reserve_seconds}" + ) + if h.codebook_hessian_damp < 0: + raise ValueError( + f"CODEBOOK_HESSIAN_DAMP must be non-negative, got {h.codebook_hessian_damp}" + ) + if h.codebook_lattice_scale <= 0: + raise ValueError(f"CODEBOOK_LATTICE_SCALE must be positive, got {h.codebook_lattice_scale}") + if h.codebook_scale_bits <= 0 or h.codebook_scale_bits > 16: + raise ValueError(f"CODEBOOK_SCALE_BITS must be in [1, 16], got {h.codebook_scale_bits}") + if h.codebook_assignment_entropy_weight < 0: + raise ValueError( + f"CODEBOOK_ASSIGNMENT_ENTROPY_WEIGHT must be non-negative, got {h.codebook_assignment_entropy_weight}" + ) + if not 0.0 <= h.codebook_soft_snap_start_frac <= 1.0: + raise ValueError( + f"CODEBOOK_SOFT_SNAP_START_FRAC must be in [0, 1], got {h.codebook_soft_snap_start_frac}" + ) + if not 0.0 <= h.codebook_soft_snap_alpha <= 1.0: + raise ValueError(f"CODEBOOK_SOFT_SNAP_ALPHA must be in [0, 1], got {h.codebook_soft_snap_alpha}") + if h.codebook_soft_snap_update_every <= 0: + raise ValueError( + f"CODEBOOK_SOFT_SNAP_UPDATE_EVERY must be positive, got {h.codebook_soft_snap_update_every}" + ) + if not 0.0 <= h.codebook_penalty_start_frac <= 1.0: + raise ValueError( + 
f"CODEBOOK_PENALTY_START_FRAC must be in [0, 1], got {h.codebook_penalty_start_frac}" + ) + if h.codebook_penalty_weight < 0: + raise ValueError( + f"CODEBOOK_PENALTY_WEIGHT must be non-negative, got {h.codebook_penalty_weight}" + ) + if h.codebook_penalty_refresh_every <= 0: + raise ValueError( + f"CODEBOOK_PENALTY_REFRESH_EVERY must be positive, got {h.codebook_penalty_refresh_every}" + ) + if not 0.0 <= h.codebook_outlier_frac <= 1.0: + raise ValueError(f"CODEBOOK_OUTLIER_FRAC must be in [0, 1], got {h.codebook_outlier_frac}") + if h.codebook_outlier_max_count < 0: + raise ValueError(f"CODEBOOK_OUTLIER_MAX_COUNT must be non-negative, got {h.codebook_outlier_max_count}") + + +def _blockify_weight(t: Tensor, block_dim: int) -> tuple[Tensor, tuple[int, int]]: + if t.ndim != 2: + raise ValueError(f"Codebook quantization expects a 2D tensor, got shape {tuple(t.shape)}") + if t.shape[1] % block_dim != 0: + raise ValueError( + f"Tensor shape {tuple(t.shape)} is not compatible with CODEBOOK_BLOCK_DIM={block_dim}" + ) + return t.contiguous().view(-1, block_dim), (int(t.shape[0]), int(t.shape[1])) + + +def _unblockify_weight(blocks: Tensor, original_shape: tuple[int, int]) -> Tensor: + return blocks.contiguous().view(original_shape) + + +def _should_codebook_quantize(name: str, t: Tensor, h: Hyperparameters) -> bool: + return ( + t.is_floating_point() + and t.ndim == 2 + and t.shape[1] % h.codebook_block_dim == 0 + and name.startswith("blocks.") + and any(name.endswith(pattern) for pattern in CODEBOOK_TARGET_PATTERNS) + ) +def _hadamard_sign_vector(name: str, block_dim: int, *, device: torch.device, dtype: torch.dtype) -> Tensor: + rng = random.Random(f"hadamard::{name}::{block_dim}") + return torch.tensor( + [1.0 if rng.getrandbits(1) else -1.0 for _ in range(block_dim)], + device=device, + dtype=dtype, + ) + + +def hadamard_transform(x: Tensor, scale: float = 1.0) -> Tensor: + original_shape = x.shape + dim = x.shape[-1] + padded_dim = 1 << (max(dim, 1) - 
1).bit_length() + out = x.reshape(-1, dim) + if padded_dim != dim: + out = F.pad(out, (0, padded_dim - dim)) + h = 1 + while h < padded_dim: + out = out.view(-1, padded_dim // (2 * h), 2, h) + a = out[:, :, 0, :] + b = out[:, :, 1, :] + out = torch.stack((a + b, a - b), dim=2).reshape(-1, padded_dim) + h *= 2 + out = out[:, :dim] + return (out * scale).reshape(*original_shape) + + +def hadamard_rotate_blocks(blocks: Tensor, sign_vec: Tensor, *, enabled: bool = True) -> Tensor: + if not enabled: + return blocks + return hadamard_transform(blocks * sign_vec, scale=blocks.shape[-1] ** -0.5) + + +def hadamard_unrotate_blocks(blocks: Tensor, sign_vec: Tensor, *, enabled: bool = True) -> Tensor: + if not enabled: + return blocks + return hadamard_transform(blocks, scale=blocks.shape[-1] ** -0.5) * sign_vec + + +def get_norm12() -> Tensor: + return torch.tensor([ + [3, 1, 1, 1, 3, 3, 3, 3], + [1, 3, 1, 1, 3, 3, 3, 3], + [1, 1, 3, 1, 3, 3, 3, 3], + [1, 1, 1, 3, 3, 3, 3, 3], + [3, 3, 3, 1, 3, 3, 1, 1], + [3, 3, 3, 1, 3, 1, 3, 1], + [3, 3, 3, 1, 1, 3, 3, 1], + [3, 3, 3, 1, 3, 1, 1, 3], + [3, 3, 3, 1, 1, 3, 1, 3], + [3, 3, 3, 1, 1, 1, 3, 3], + [3, 3, 1, 3, 3, 3, 1, 1], + [3, 3, 1, 3, 3, 1, 3, 1], + [3, 3, 1, 3, 1, 3, 3, 1], + [3, 3, 1, 3, 3, 1, 1, 3], + [3, 3, 1, 3, 1, 3, 1, 3], + [3, 3, 1, 3, 1, 1, 3, 3], + [3, 1, 3, 3, 3, 3, 1, 1], + [3, 1, 3, 3, 3, 1, 3, 1], + [3, 1, 3, 3, 1, 3, 3, 1], + [3, 1, 3, 3, 3, 1, 1, 3], + [3, 1, 3, 3, 1, 3, 1, 3], + [1, 3, 3, 3, 1, 1, 3, 3], + [1, 3, 3, 3, 3, 3, 1, 1], + [1, 3, 3, 3, 3, 1, 3, 1], + [1, 3, 3, 3, 1, 3, 3, 1], + [1, 3, 3, 3, 3, 1, 1, 3], + [1, 3, 3, 3, 1, 3, 1, 3], + [1, 1, 3, 3, 1, 3, 3, 3], + [3, 3, 1, 1, 3, 3, 3, 1], + ], dtype=torch.float32) / 2 + + +def get_packed_abs_grid() -> Tensor: + intr = torch.arange(-4, 4) + d8 = torch.cartesian_prod(*[intr] * 8).float() + 0.5 + d8_even = d8.sum(dim=-1) % 2 == 0 + d8_small = d8.norm(dim=-1).square() <= 10 + d8_abs = torch.unique(d8[torch.where(d8_even & d8_small)[0]].abs(), dim=0) + 
codebook_abs = torch.cat([d8_abs, get_norm12()], dim=0) + codebook_abs = codebook_abs[:, [0, 2, 4, 6, 1, 3, 5, 7]] + codebook_abs[:, 7] *= (1 - 2 * (codebook_abs.sum(1) % 2)) + codebook_abs = (codebook_abs * 2 + 8).to(torch.int32) + acc = codebook_abs[:, 0] + for i in range(7): + acc = acc | (codebook_abs[:, i + 1] << ((i + 1) * 4)) + return acc + + +def get_abs_grid() -> Tensor: + intr = torch.arange(-4, 4) + d8 = torch.cartesian_prod(*[intr] * 8).float() + 0.5 + d8_even = d8.sum(dim=-1) % 2 == 0 + d8_small = d8.norm(dim=-1).square() <= 10 + d8_abs = torch.unique(d8[torch.where(d8_even & d8_small)[0]].abs(), dim=0) + return torch.cat([d8_abs, get_norm12()], dim=0) + + +def get_full_grid(packed_abs_grid: Tensor) -> tuple[Tensor, Tensor]: + codebook = torch.zeros(1 << 16, 8) + parity_idx: list[int] = [] + shuffle_map = [0, 4, 1, 5, 2, 6, 3, 7] + for code in range(1 << 16): + signs = code & 255 + abs_idx = code >> 8 + parity = 0 + for bit in range(8): + parity ^= (signs >> bit) & 1 + signs ^= parity + abs_code = packed_abs_grid[abs_idx].item() + for i, shuffled in enumerate(shuffle_map): + codebook[code, i] = (((abs_code >> (4 * shuffled)) & 15) - 8) * 0.5 + if (signs >> shuffled) & 1: + codebook[code, i] *= -1 + if parity: + codebook[code] -= 0.25 + parity_idx.append(code) + else: + codebook[code] += 0.25 + return codebook, torch.tensor(parity_idx, dtype=torch.long) + + +_E8P_PACKED_ABS = get_packed_abs_grid() +_E8P_GRID, _E8P_PARITY_IDX = get_full_grid(_E8P_PACKED_ABS) + + +class E8P12Codebook(nn.Module): + def __init__(self) -> None: + super().__init__() + self.codesz = 8 + self.register_buffer("grid", _E8P_GRID) + self.register_buffer("grid_norm", _E8P_GRID.norm(dim=-1).square()) + grid_part = _E8P_GRID[_E8P_PARITY_IDX] + 0.25 + grid_part = grid_part[ + torch.where( + ((grid_part[:, :7] < 0).sum(dim=-1) <= 1) + & (grid_part[:, :7].min(dim=-1).values >= -0.5) + )[0] + ] + abs_grid = get_abs_grid() + self.register_buffer("grid_part", grid_part) + 
self.register_buffer("grid_part_norm", grid_part.norm(dim=-1).square()) + self.register_buffer("grid_abs_odd", abs_grid.sum(dim=-1) % 2 == 1) + self.register_buffer("part_abs_map", self.round(grid_part.abs(), abs_grid, abs_grid.norm(dim=-1).square())[1]) + self.register_buffer("bit_map", 2 ** torch.arange(8)) + + def round(self, x: Tensor, grid: Tensor, grid_norm: Tensor) -> tuple[Tensor, Tensor]: + idx = (2 * x @ grid.t() - grid_norm).argmax(dim=-1) + return grid[idx], idx + + def round_topk(self, x: Tensor, grid: Tensor, grid_norm: Tensor, k: int = 2) -> tuple[Tensor, Tensor]: + scores = 2 * x @ grid.t() - grid_norm + top_idx = scores.topk(k, dim=-1).indices + return grid[top_idx], top_idx + + def fast_quantize_part(self, x: Tensor, parity: bool) -> tuple[Tensor, Tensor, Tensor]: + x_part = torch.abs(x) + x_odd = torch.where((x < 0).sum(dim=-1) % 2 != 0)[0] + x_part[x_odd, 7] = -x_part[x_odd, 7] + mask = 1 - 2 * (x < 0).to(torch.float32) + mask[x_odd, 7] = -mask[x_odd, 7] + rounded, rounded_idx = self.round(x_part, self.grid_part, self.grid_part_norm) + vals = rounded * mask + err = (x - vals).norm(dim=-1) + abs_idx = self.part_abs_map[rounded_idx] + sign_mask = ((rounded < 0) ^ (mask < 0))[:, [0, 2, 4, 6, 1, 3, 5, 7]] + sign_mask[:, 7] ^= self.grid_abs_odd[abs_idx] + sign_mask[:, 0] ^= parity + idx = (abs_idx << 8) + (sign_mask * self.bit_map).sum(dim=-1).int() + return vals, idx, err + + def fast_quantize_part_topk(self, x: Tensor, parity: bool, k: int = 2) -> tuple[Tensor, Tensor, Tensor]: + x_part = torch.abs(x) + x_odd = torch.where((x < 0).sum(dim=-1) % 2 != 0)[0] + x_part[x_odd, 7] = -x_part[x_odd, 7] + mask = 1 - 2 * (x < 0).to(torch.float32) + mask[x_odd, 7] = -mask[x_odd, 7] + rounded, rounded_idx = self.round_topk(x_part, self.grid_part, self.grid_part_norm, k=k) + vals = rounded * mask.unsqueeze(1) + err = (x.unsqueeze(1) - vals).square().sum(dim=-1) + abs_idx = self.part_abs_map[rounded_idx] + sign_mask = ((rounded < 0) ^ (mask.unsqueeze(1) < 0))[:, 
:, [0, 2, 4, 6, 1, 3, 5, 7]] + sign_mask[:, :, 7] ^= self.grid_abs_odd[abs_idx] + sign_mask[:, :, 0] ^= parity + idx = (abs_idx << 8) + (sign_mask * self.bit_map.view(1, 1, -1)).sum(dim=-1).int() + return vals, idx, err + + def quantize(self, x: Tensor) -> tuple[Tensor, Tensor]: + plus_vals, plus_idx, plus_err = self.fast_quantize_part(x + 0.25, True) + minus_vals, minus_idx, minus_err = self.fast_quantize_part(x - 0.25, False) + use_plus = plus_err < minus_err + vals = torch.where(use_plus.unsqueeze(-1), plus_vals - 0.25, minus_vals + 0.25) + idx = torch.where(use_plus, plus_idx, minus_idx) + return vals, idx + + def quantize_top2(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: + plus_vals, plus_idx, plus_err = self.fast_quantize_part_topk(x + 0.25, True, k=2) + minus_vals, minus_idx, minus_err = self.fast_quantize_part_topk(x - 0.25, False, k=2) + cand_vals = torch.cat((plus_vals - 0.25, minus_vals + 0.25), dim=1) + cand_idx = torch.cat((plus_idx, minus_idx), dim=1) + cand_err = torch.cat((plus_err, minus_err), dim=1) + order = cand_err.topk(2, dim=1, largest=False).indices + gather_vals = cand_vals.gather(1, order.unsqueeze(-1).expand(-1, -1, self.codesz)) + gather_idx = cand_idx.gather(1, order) + gather_err = cand_err.gather(1, order) + return gather_vals[:, 0], gather_idx[:, 0], gather_vals[:, 1], gather_err[:, 1] - gather_err[:, 0] + + def decode(self, idxs: Tensor) -> Tensor: + idxs_long = idxs.long().reshape(-1) + vals = self.grid.index_select(0, idxs_long) + return vals.view(*idxs.shape, self.codesz) + + +_E8P_CODEBOOK_CACHE: dict[str, E8P12Codebook] = {} + + +def _e8p_cache_key(device: torch.device | None) -> str: + if device is None: + return "cpu" + return f"{device.type}:{device.index}" + + +def _get_e8p_codebook(device: torch.device | None) -> E8P12Codebook: + key = _e8p_cache_key(device) + cached = _E8P_CODEBOOK_CACHE.get(key) + if cached is not None: + return cached + codebook = E8P12Codebook() + if device is not None: + codebook = 
+        codebook.to(device=device, dtype=torch.float32)
+    else:
+        codebook = codebook.to(dtype=torch.float32)
+    codebook.eval()
+    _E8P_CODEBOOK_CACHE[key] = codebook
+    return codebook
+
+
+@torch.no_grad()
+def _quantize_e8p_blocks(blocks: Tensor, lattice_scale: float) -> tuple[Tensor, Tensor]:
+    """Snap pre-scaled blocks to the E8P lattice; returns (dequantized values, codeword indices)."""
+    codebook = _get_e8p_codebook(blocks.device)
+    scaled = blocks.to(dtype=torch.float32) * float(lattice_scale)
+    vals, idxs = codebook.quantize(scaled)
+    return vals / float(lattice_scale), idxs.long()
+
+
+@torch.no_grad()
+def _decode_e8p_blocks(
+    idxs: Tensor,
+    lattice_scale: float,
+    *,
+    device: torch.device,
+    dtype: torch.dtype,
+) -> Tensor:
+    """Decode codeword indices back to real-valued blocks on the given device/dtype."""
+    codebook = _get_e8p_codebook(device)
+    vals = codebook.decode(idxs.to(device=device))
+    return (vals / float(lattice_scale)).to(dtype=dtype)
+
+
+@torch.no_grad()
+def _fast_quantize_weight_target(h: Hyperparameters, name: str, weight: Tensor) -> Tensor:
+    """Direction-only quantization target: snap each rotated block to the lattice, keep its norm."""
+    weight_f32 = weight.detach().float()
+    blocks, shape = _blockify_weight(weight_f32, h.codebook_block_dim)
+    sign_vec = _hadamard_sign_vector(name, h.codebook_block_dim, device=weight.device, dtype=torch.float32)
+    rotated_blocks = hadamard_rotate_blocks(blocks, sign_vec, enabled=h.codebook_use_hadamard)
+    block_norms = rotated_blocks.norm(dim=-1, keepdim=True).clamp_min(1e-8)
+    codebook = _get_e8p_codebook(weight.device)
+    target_radius = float(codebook.grid.norm(dim=-1).median().item())
+    proxy_blocks = rotated_blocks / block_norms * target_radius
+    quantized_proxy, _ = _quantize_e8p_blocks(proxy_blocks, h.codebook_lattice_scale)
+    quantized_dirs = F.normalize(quantized_proxy.to(dtype=torch.float32), dim=-1)
+    recon_rotated = quantized_dirs * block_norms
+    recon_blocks = hadamard_unrotate_blocks(
+        recon_rotated,
+        sign_vec,
+        enabled=h.codebook_use_hadamard,
+    )
+    return _unblockify_weight(recon_blocks, shape).to(dtype=weight.dtype)
+
+
+def _log_prior_from_assignments(
+    assignments: Tensor,
+    num_entries: int,
+    *,
+    smoothing: float = 1.0,
+) -> Tensor:
+    """Laplace-smoothed log-frequency prior over codebook entries from a round of assignments."""
+    flat = assignments.reshape(-1).to(dtype=torch.int64)
+    counts = torch.bincount(flat, minlength=num_entries).to(dtype=torch.float32)
+    probs = (counts + float(smoothing)) / float(counts.sum().item() + smoothing * num_entries)
+    return torch.log(probs)
+
+
+@torch.no_grad()
+def _metric_optimal_codewords(
+    rotated_blocks: Tensor,
+    metrics: Tensor,
+    candidates: Tensor,
+    *,
+    candidate_batch_size: int = 4096,
+    assignment_log_prior: Tensor | None = None,
+    assignment_entropy_weight: float = 0.0,
+) -> tuple[Tensor, Tensor]:
+    """Per block, select the codeword maximizing the metric-weighted projection gain."""
+    if rotated_blocks.ndim != 3:
+        raise ValueError(f"Expected rotated blocks with shape [rows, positions, dim], got {tuple(rotated_blocks.shape)}")
+    expected_metrics_shape = (rotated_blocks.shape[1], rotated_blocks.shape[2], rotated_blocks.shape[2])
+    if metrics.ndim != 3 or metrics.shape != expected_metrics_shape:
+        raise ValueError(f"Expected metrics with shape {expected_metrics_shape}, got {tuple(metrics.shape)}")
+    if candidates.ndim != 2 or candidates.shape[1] != rotated_blocks.shape[2]:
+        raise ValueError(f"Expected candidates [entries, {rotated_blocks.shape[2]}], got {tuple(candidates.shape)}")
+    if assignment_log_prior is not None and assignment_log_prior.shape != (candidates.shape[0],):
+        raise ValueError(
+            f"Expected assignment_log_prior with shape ({candidates.shape[0]},), got {tuple(assignment_log_prior.shape)}"
+        )
+    num_rows, num_positions, _ = rotated_blocks.shape
+    selected_idx = torch.empty((num_rows, num_positions), device=rotated_blocks.device, dtype=torch.long)
+    selected_codewords = torch.empty_like(rotated_blocks, dtype=torch.float32)
+    for pos in range(num_positions):
+        metric = metrics[pos]
+        x = rotated_blocks[:, pos, :].to(dtype=torch.float32)
+        x_metric = torch.matmul(x, metric)
+        best_improvement = torch.full((num_rows,), -torch.inf, device=x.device, dtype=torch.float32)
+        best_idx = torch.zeros((num_rows,), device=x.device, dtype=torch.long)
+        for start in range(0, candidates.shape[0], candidate_batch_size):
+            cand = candidates[start:start + candidate_batch_size]
+            cand_metric = torch.matmul(cand, metric)
+            numer = torch.matmul(x_metric, cand.t())
+            denom = (cand_metric * cand).sum(dim=-1).clamp_min(1e-8)
+            # x^T M x is candidate-independent, so minimizing the optimal metric error is
+            # equivalent to maximizing the gain from the best non-negative scale.
+            improvement = numer.clamp_min(0.0).square() / denom.unsqueeze(0)
+            if assignment_log_prior is not None and assignment_entropy_weight > 0:
+                improvement = improvement + float(assignment_entropy_weight) * assignment_log_prior[
+                    start:start + cand.shape[0]
+                ].unsqueeze(0)
+            chunk_best_improvement, chunk_best_idx = improvement.max(dim=1)
+            update = chunk_best_improvement > best_improvement
+            if update.any():
+                best_improvement = torch.where(update, chunk_best_improvement, best_improvement)
+                best_idx = torch.where(update, start + chunk_best_idx, best_idx)
+        selected_idx[:, pos] = best_idx
+        selected_codewords[:, pos, :] = candidates.index_select(0, best_idx)
+    return selected_codewords, selected_idx
+
+
+def collect_hessians(
+    model: nn.Module,
+    train_loader: DistributedTokenLoader,
+    h: Hyperparameters,
+    device: torch.device,
+    target_names: set[str],
+    n_calibration_batches: int = 64,
+) -> dict[str, Tensor]:
+    """Run calibration batches and collect H = X^T X for selected CastedLinear layers."""
+    if n_calibration_batches <= 0 or not target_names:
+        return {}
+    hessians: dict[str, Tensor] = {}
+    hooks = []
+    was_training = model.training
+
+    def make_hook(name: str):
+        def hook_fn(module, inp, out):
+            x = inp[0].detach().float()
+            if x.ndim == 3:
+                x = x.reshape(-1, x.shape[-1])
+            if name not in hessians:
+                hessians[name] = torch.zeros(
+                    x.shape[1], x.shape[1], dtype=torch.float32, device=device
+                )
+            hessians[name].addmm_(x.T, x)
+        return hook_fn
+
+    for module_name, module in model.named_modules():
+        if not isinstance(module, CastedLinear):
+            continue
+        weight_name = f"{module_name}.weight"
+        if weight_name in target_names:
+            hooks.append(module.register_forward_hook(make_hook(weight_name)))
+
+    model.eval()
+    with torch.no_grad():
+        for _ in range(n_calibration_batches):
+            x, _ = train_loader.next_batch(
+                h.train_batch_tokens,
+                h.train_seq_len,
+                h.grad_accum_steps,
+            )
+            model.forward_logits(x)
+
+    for hook in hooks:
+        hook.remove()
+
+    world = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
+    for name, hessian in hessians.items():
+        if world > 1:
+            dist.all_reduce(hessian, op=dist.ReduceOp.SUM)
+        hessians[name] = hessian.cpu() / float(max(world * n_calibration_batches, 1))
+
+    if was_training:
+        model.train()
+    return hessians
+
+
+def _identity_block_metrics(shape: tuple[int, int], block_dim: int, device: torch.device) -> Tensor:
+    num_positions = shape[1] // block_dim
+    eye = torch.eye(block_dim, device=device, dtype=torch.float32)
+    return eye.unsqueeze(0).repeat(num_positions, 1, 1)
+
+
+def _block_metrics_from_hessian(
+    shape: tuple[int, int],
+    block_dim: int,
+    sign_vec: Tensor,
+    hessian: Tensor | None,
+    *,
+    use_hadamard: bool,
+    damp_factor: float,
+    device: torch.device,
+) -> Tensor:
+    if hessian is None or hessian.ndim != 2 or hessian.shape != (shape[1], shape[1]):
+        return _identity_block_metrics(shape, block_dim, device)
+    num_positions = shape[1] // block_dim
+    hessian_work = hessian.to(device=device, dtype=torch.float32).clone()
+    diag = hessian_work.diag()
+    dead = diag <= 0
+    if dead.any():
+        hessian_work[dead, dead] = 1.0
+    damp = float(damp_factor * hessian_work.diag().mean().clamp_min(1e-8).item())
+    hessian_work.diagonal().add_(damp)
+    reshaped = hessian_work.view(num_positions, block_dim, num_positions, block_dim)
+    block_idx = torch.arange(num_positions, device=device)
+    block_metrics = reshaped[block_idx, :, block_idx, :]
+    if use_hadamard:
+        unrotate = hadamard_unrotate_blocks(
+            torch.eye(block_dim, device=device, dtype=torch.float32),
+            sign_vec,
+            enabled=True,
+        )
+        block_metrics = torch.matmul(
+            unrotate.unsqueeze(0),
+            torch.matmul(block_metrics, unrotate.t().unsqueeze(0)),
+        )
+    block_metrics = 0.5 * (block_metrics + block_metrics.transpose(-1, -2))
+    block_metrics.diagonal(dim1=-2, dim2=-1).add_(1e-6)
+    return block_metrics
+
+
+def _metric_optimal_scales(rotated_blocks: Tensor, codebook_blocks: Tensor, metrics: Tensor) -> Tensor:
+    numer = torch.einsum("npd,pde,npe->np", codebook_blocks, metrics, rotated_blocks)
+    denom = torch.einsum("npd,pde,npe->np", codebook_blocks, metrics, codebook_blocks).clamp_min(1e-8)
+    scales = (numer / denom).clamp_min(1e-8)
+    max_fp16 = float(torch.finfo(torch.float16).max)
+    return scales.clamp(max=max_fp16).unsqueeze(-1)
+
+
+@torch.no_grad()
+def _metric_optimal_e8p_codewords(
+    rotated_blocks: Tensor,
+    metrics: Tensor,
+    lattice_scale: float,
+    *,
+    candidate_batch_size: int = 4096,
+    assignment_log_prior: Tensor | None = None,
+    assignment_entropy_weight: float = 0.0,
+) -> tuple[Tensor, Tensor]:
+    codebook = _get_e8p_codebook(rotated_blocks.device)
+    candidates = codebook.grid.to(dtype=torch.float32) / float(lattice_scale)
+    return _metric_optimal_codewords(
+        rotated_blocks,
+        metrics,
+        candidates,
+        candidate_batch_size=candidate_batch_size,
+        assignment_log_prior=assignment_log_prior,
+        assignment_entropy_weight=assignment_entropy_weight,
+    )
+
+
+class FixedCodebookSoftSnap:
+    def __init__(self, h: Hyperparameters, model: nn.Module):
+        _validate_codebook_hparams(h)
+        self.h = h
+        self.target_modules: list[tuple[str, CastedLinear]] = []
+        self.target_weights = 0
+        self.active = False
+        self.last_snap_step: int | None = None
+        for module_name, module in model.named_modules():
+            if not isinstance(module, CastedLinear):
+                continue
+            weight_name = f"{module_name}.weight"
+            if _should_codebook_quantize(weight_name, module.weight, h):
+                self.target_modules.append((weight_name, module))
+                self.target_weights += int(module.weight.numel())
+
+    def should_activate(self, frac: float) -> bool:
+        return self.h.codebook_soft_snap_enabled and not self.active and frac >= self.h.codebook_soft_snap_start_frac
+
+    @torch.no_grad()
+    def apply(self, step: int, *, force: bool = False) -> bool:
+        if not self.active or not self.target_modules:
+            return False
+        if not force and self.last_snap_step is not None:
+            if (step - self.last_snap_step) < self.h.codebook_soft_snap_update_every:
+                return False
+        alpha = float(self.h.codebook_soft_snap_alpha)
+        if alpha <= 0.0:
+            return False
+        for name, module in self.target_modules:
+            q_weight = _fast_quantize_weight_target(self.h, name, module.weight)
+            module.weight.lerp_(q_weight, alpha)
+        self.last_snap_step = step
+        return True
+
+    def activate(self, step: int, frac: float) -> None:
+        if self.active:
+            return
+        self.active = True
+        if self.h.is_main_process:
+            log(
+                f"codebook_soft_snap:enabled step:{step}/{self.h.iterations} "
+                f"start_frac:{self.h.codebook_soft_snap_start_frac:.3f} frac:{frac:.3f} "
+                f"mode:fast_direction "
+                f"alpha:{self.h.codebook_soft_snap_alpha:.4f} "
+                f"target_tensors:{len(self.target_modules)} target_weights:{self.target_weights} "
+                f"update_every:{self.h.codebook_soft_snap_update_every}"
+            )
+
+
+class FixedCodebookPenalty:
+    def __init__(self, h: Hyperparameters, model: nn.Module):
+        _validate_codebook_hparams(h)
+        self.h = h
+        self.target_modules: list[tuple[str, CastedLinear]] = []
+        self.target_weights = 0
+        self.active = False
+        self.last_refresh_step: int | None = None
+        self.cached_targets: list[dict[str, object]] = []
+        for module_name, module in model.named_modules():
+            if not isinstance(module, CastedLinear):
+                continue
+            weight_name = f"{module_name}.weight"
+            if _should_codebook_quantize(weight_name, module.weight, h):
+                self.target_modules.append((weight_name, module))
+                self.target_weights += int(module.weight.numel())
+
+    def should_activate(self, frac: float) -> bool:
+        return self.h.codebook_penalty_enabled and not self.active and frac >= self.h.codebook_penalty_start_frac
+
+    def activate(self, step: int, frac: float) -> None:
+        if self.active:
+            return
+        self.active = True
+        if self.h.is_main_process:
+            log(
+                f"codebook_penalty:enabled step:{step}/{self.h.iterations} "
+                f"start_frac:{self.h.codebook_penalty_start_frac:.3f} frac:{frac:.3f} "
+                f"weight:{self.h.codebook_penalty_weight:.5f} "
+                f"mode:fast_direction_target "
+                f"refresh_every:{self.h.codebook_penalty_refresh_every} "
+                f"target_tensors:{len(self.target_modules)} target_weights:{self.target_weights}"
+            )
+
+    def _current_weight(self, frac: float) -> float:
+        if frac <= self.h.codebook_penalty_start_frac:
+            return 0.0
+        return float(self.h.codebook_penalty_weight)
+
+    @torch.no_grad()
+    def _refresh_targets(self, step: int) -> None:
+        refreshed: list[dict[str, object]] = []
+        for name, module in self.target_modules:
+            q_weight = _fast_quantize_weight_target(self.h, name, module.weight)
+            refreshed.append(
+                {
+                    "name": name,
+                    "module": module,
+                    "target_weight": q_weight.detach(),
+                }
+            )
+        self.cached_targets = refreshed
+        self.last_refresh_step = step
+
+    def penalty(self, step: int, frac: float) -> Tensor | None:
+        if not self.active:
+            return None
+        weight = self._current_weight(frac)
+        if weight <= 0.0:
+            return None
+        if self.last_refresh_step is None or (step - self.last_refresh_step) >= self.h.codebook_penalty_refresh_every:
+            self._refresh_targets(step)
+        if not self.cached_targets:
+            return None
+        losses: list[Tensor] = []
+        for item in self.cached_targets:
+            module = item["module"]
+            target_weight = item["target_weight"]
+            diff = module.weight.float() - target_weight.float()
+            mse = diff.square().mean()
+            denom = target_weight.float().square().mean().clamp_min(1e-8)
+            losses.append(mse / denom)
+        if not losses:
+            return None
+        return float(weight) * torch.stack(losses).mean()
+
+
+class CodebookQuantizer:
+    def __init__(self, h: Hyperparameters, model: nn.Module):
+        _validate_codebook_hparams(h)
+        self.h = h
+        self.states: dict[str, dict[str, object]] = {}
+        self.target_modules: list[tuple[str, CastedLinear]] = []
+        self.target_names: set[str] = set()
+        self.last_fit_summary: dict[str, float] = {}
+        for module_name, module in model.named_modules():
+            if not isinstance(module, CastedLinear):
+                continue
+            weight_name = f"{module_name}.weight"
+            if _should_codebook_quantize(weight_name, module.weight, h):
+                self.target_modules.append((weight_name, module))
+                self.target_names.add(weight_name)
+
+    @torch.no_grad()
+    def _prepare_rotated_blocks(
+        self,
+        name: str,
+        module: CastedLinear,
+    ) -> tuple[Tensor, tuple[int, int], Tensor]:
+        weight = module.weight.detach().float()
+        blocks, shape = _blockify_weight(weight, self.h.codebook_block_dim)
+        sign_vec = _hadamard_sign_vector(name, self.h.codebook_block_dim, device=blocks.device, dtype=torch.float32)
+        rotated_blocks = hadamard_rotate_blocks(blocks, sign_vec, enabled=self.h.codebook_use_hadamard)
+        return rotated_blocks, shape, sign_vec
+
+    def _fit_tensor(self, name: str, module: CastedLinear, hessian: Tensor | None) -> tuple[str, dict[str, float]]:
+        weight = module.weight.detach().float()
+        rotated_blocks, shape, sign_vec = self._prepare_rotated_blocks(name, module)
+        blocks, _ = _blockify_weight(weight, self.h.codebook_block_dim)
+        num_rows, num_cols = shape
+        num_positions = num_cols // self.h.codebook_block_dim
+        rotated_grid = rotated_blocks.view(num_rows, num_positions, self.h.codebook_block_dim)
+        metrics = _block_metrics_from_hessian(
+            shape,
+            self.h.codebook_block_dim,
+            sign_vec,
+            hessian,
+            use_hadamard=self.h.codebook_use_hadamard,
+            damp_factor=self.h.codebook_hessian_damp,
+            device=blocks.device,
+        )
+        initial_dirs, initial_idx = _metric_optimal_e8p_codewords(
+            rotated_grid,
+            metrics,
+            self.h.codebook_lattice_scale,
+        )
+        if self.h.codebook_assignment_entropy_weight > 0:
+            e8p_entries = int(_get_e8p_codebook(rotated_blocks.device).grid.shape[0])
+            assignment_log_prior = _log_prior_from_assignments(initial_idx, e8p_entries)
+            fixed_dirs, fixed_idx = _metric_optimal_e8p_codewords(
+                rotated_grid,
+                metrics,
+                self.h.codebook_lattice_scale,
+                assignment_log_prior=assignment_log_prior,
+                assignment_entropy_weight=self.h.codebook_assignment_entropy_weight,
+            )
+        else:
+            assignment_log_prior = None
+            fixed_dirs, fixed_idx = initial_dirs, initial_idx
+        scales = _metric_optimal_scales(rotated_grid, fixed_dirs, metrics)
+        scales_fp16 = scales.to(dtype=torch.float16)
+        recon_rotated = fixed_dirs * scales_fp16.to(dtype=torch.float32)
+        recon_blocks = hadamard_unrotate_blocks(
+            recon_rotated.view(-1, self.h.codebook_block_dim),
+            sign_vec,
+            enabled=self.h.codebook_use_hadamard,
+        )
+        diff = blocks - recon_blocks
+        mse = diff.square().mean()
+        energy = blocks.square().mean().clamp_min(1e-12)
+        stats = {
+            "num_weights": float(weight.numel()),
+            "mse": float(mse.item()),
+            "rel_mse": float((mse / energy).item()),
+            "scale_min": float(scales_fp16.min().item()),
+            "scale_mean": float(scales_fp16.float().mean().item()),
+            "scale_max": float(scales_fp16.max().item()),
+        }
+        state: dict[str, object] = {
+            "shape": shape,
+            "block_dim": self.h.codebook_block_dim,
+            "fixed_idx": fixed_idx.reshape(-1).to(dtype=torch.uint16).cpu().contiguous(),
+            "scales": scales_fp16.reshape(-1).cpu().contiguous(),
+            "stats": stats,
+        }
+        if assignment_log_prior is not None:
+            probs = assignment_log_prior.exp()
+            stats["assignment_prior_entropy"] = float((-(probs * assignment_log_prior).sum()).item())
+        self.states[name] = state
+        return name, stats
+
+    @torch.no_grad()
+    def fit(self, model: nn.Module, hessians: dict[str, Tensor]) -> None:
+        if not self.target_modules:
+            if self.h.is_main_process:
+                log("codebook:no eligible tensors found")
+            return
+        total_fp_weights = sum(
+            int(t.numel())
+            for _, t in model.state_dict().items()
+            if t.is_floating_point()
+        )
+        fit_items = [self._fit_tensor(name, module, hessians.get(name)) for name, module in self.target_modules]
+        target_weights = sum(int(stats["num_weights"]) for _, stats in fit_items)
+        weighted_rel_mse = sum(float(stats["rel_mse"]) * int(stats["num_weights"]) for _, stats in fit_items)
+        self.last_fit_summary = {
+            "target_tensors": float(len(fit_items)),
+            "target_weights": float(target_weights),
+            "coverage": target_weights / max(total_fp_weights, 1),
+            "rel_mse": weighted_rel_mse / max(target_weights, 1),
+            "target_bpw": 4.0,
+        }
+        if self.h.is_main_process:
+            log(
+                f"codebook:fit target_tensors:{len(fit_items)} target_weights:{target_weights} "
+                f"coverage:{self.last_fit_summary['coverage']:.4%} rel_mse:{self.last_fit_summary['rel_mse']:.6e} "
+                f"target_bpw:{self.last_fit_summary['target_bpw']:.4f}"
+            )
+
+    @torch.no_grad()
+    def _reconstruct_tensor_from_state(self, name: str, state: dict[str, object], *, dtype: torch.dtype = torch.float32) -> Tensor:
+        shape = tuple(int(x) for x in state["shape"])
+        block_dim = int(state["block_dim"])
+        idx = state["fixed_idx"].to(dtype=torch.int64)
+        scales = state["scales"].to(dtype=torch.float32).view(-1, 1)
+        fixed_blocks = _decode_e8p_blocks(
+            idx,
+            float(self.h.codebook_lattice_scale),
+            device=torch.device("cpu"),
+            dtype=torch.float32,
+        )
+        rotated_blocks = fixed_blocks * scales
+        sign_vec = _hadamard_sign_vector(name, block_dim, device=torch.device("cpu"), dtype=torch.float32)
+        blocks = hadamard_unrotate_blocks(
+            rotated_blocks,
+            sign_vec,
+            enabled=bool(self.h.codebook_use_hadamard),
+        )
+        return _unblockify_weight(blocks, shape).to(dtype)
+
+    @torch.no_grad()
+    def _select_outliers(self, original: Tensor, reconstructed: Tensor) -> tuple[Tensor | None, Tensor | None, Tensor | None]:
+        requested = 0
+        if self.h.codebook_outlier_frac > 0:
+            requested = int(math.ceil(float(self.h.codebook_outlier_frac) * float(original.numel())))
+        if self.h.codebook_outlier_max_count > 0:
+            requested = (
+                min(requested, int(self.h.codebook_outlier_max_count))
+                if requested > 0
+                else int(self.h.codebook_outlier_max_count)
+            )
+        requested = min(requested, int(original.numel()))
+        if requested <= 0:
+            return None, None, None
+        flat_orig = original.reshape(-1).to(dtype=torch.float32)
+        flat_recon = reconstructed.reshape(-1).to(dtype=torch.float32)
+        err = (flat_orig - flat_recon).square()
+        if requested >= flat_orig.numel():
+            idx = torch.arange(flat_orig.numel(), dtype=torch.int32)
+        else:
+            idx = torch.topk(err, k=requested, largest=True, sorted=False).indices.to(dtype=torch.int64)
+            idx = idx.index_select(0, torch.argsort(idx)).to(dtype=torch.int32).cpu()
+        vals = flat_orig.index_select(0, idx.to(dtype=torch.int64)).cpu().contiguous()
+        max_abs = float(vals.abs().max().item()) if vals.numel() > 0 else 0.0
+        if max_abs <= 1e-12:
+            scale = torch.ones((1,), dtype=torch.float16)
+            q = torch.zeros(vals.numel(), dtype=torch.int8)
+        else:
+            scale_value = max_abs / 127.0
+            q = torch.round(vals / scale_value).clamp_(-127, 127).to(dtype=torch.int8).contiguous()
+            scale = torch.tensor([scale_value], dtype=torch.float16)
+        return idx.contiguous(), q.cpu().contiguous(), scale.cpu().contiguous()
+
+    def build_export(
+        self,
+        state_dict: dict[str, Tensor],
+    ) -> tuple[dict[str, Tensor], dict[str, object], dict[str, object]]:
+        result: dict[str, Tensor] = {}
+        meta: dict[str, object] = {}
+        export_stats: dict[str, object] = {"tensors": {}}
+        total_payload_bytes = 0
+        target_payload_bytes = 0
+        outlier_payload_bytes = 0
+        outlier_weights = 0
+        int8_fallback_payload_bytes = 0
+        int8_fallback_weights = 0
+        passthrough_payload_bytes = 0
+        passthrough_weights = 0
+        target_weights = 0
+        total_fp_weights = sum(int(t.numel()) for _, t in state_dict.items() if t.is_floating_point())
+
+        for name, tensor in state_dict.items():
+            t = tensor.detach().cpu().contiguous()
+            if name in self.states:
+                state = self.states[name]
+                idx = state["fixed_idx"]
+                scales = state["scales"]
+                scale_payload, scale_meta = _quantize_codebook_scales(scales, self.h.codebook_scale_bits)
+                result[name + ".idx"] = idx
+                result[name + ".scale"] = scale_payload
+                tensor_outlier_count = 0
+                tensor_outlier_payload_bytes = 0
+                if self.h.codebook_outlier_frac > 0 or self.h.codebook_outlier_max_count > 0:
+                    reconstructed = self._reconstruct_tensor_from_state(name, state, dtype=torch.float32)
+                    outlier_idx, outlier_q, outlier_scale = self._select_outliers(t.float(), reconstructed)
+                    if (
+                        outlier_idx is not None
+                        and outlier_q is not None
+                        and outlier_scale is not None
+                        and outlier_idx.numel() > 0
+                    ):
+                        result[name + ".outlier_idx"] = outlier_idx
+                        result[name + ".outlier_q"] = outlier_q
+                        result[name + ".outlier_scale"] = outlier_scale
+                        tensor_outlier_count = int(outlier_idx.numel())
+                        tensor_outlier_payload_bytes = (
+                            outlier_idx.numel() * outlier_idx.element_size()
+                            + outlier_q.numel() * outlier_q.element_size()
+                            + outlier_scale.numel() * outlier_scale.element_size()
+                        )
+                meta[name] = {
+                    "type": "codebook_e8p",
+                    "shape": list(state["shape"]),
+                    "block_dim": int(state["block_dim"]),
+                    "hadamard": bool(self.h.codebook_use_hadamard),
+                    "lattice_scale": float(self.h.codebook_lattice_scale),
+                    "outlier_count": tensor_outlier_count,
+                    "outlier_format": "int8" if tensor_outlier_count > 0 else "none",
+                    **scale_meta,
+                }
+                payload_bytes = (
+                    idx.numel() * idx.element_size()
+                    + scale_payload.numel() * scale_payload.element_size()
+                    + tensor_outlier_payload_bytes
+                )
+                total_payload_bytes += payload_bytes
+                target_payload_bytes += payload_bytes
+                outlier_payload_bytes += tensor_outlier_payload_bytes
+                outlier_weights += tensor_outlier_count
+                target_weights += int(state["stats"]["num_weights"])
+                export_stats["tensors"][name] = {
+                    **state["stats"],
+                    "payload_bytes": payload_bytes,
+                    "outlier_payload_bytes": tensor_outlier_payload_bytes,
+                    "outlier_count": tensor_outlier_count,
+                    "scale_bits": int(scale_meta["scale_bits"]),
+                    "scale_format": str(scale_meta["scale_format"]),
+                    "bpw": (8.0 * payload_bytes) / max(float(state["stats"]["num_weights"]), 1.0),
+                }
+                export_stats.setdefault("diagnostics", []).append(
+                    _codebook_payload_diagnostics(
+                        name,
+                        idx,
+                        scales,
+                        scale_payload,
+                        tuple(int(x) for x in state["shape"]),
+                        int(state["block_dim"]),
+                        self.h.compressor,
+                    )
+                )
+                continue
+
+            if not t.is_floating_point() or t.numel() <= 65536:
+                result[name] = t.to(torch.float16) if t.is_floating_point() else t
+                meta[name] = "passthrough"
+                payload_bytes = result[name].numel() * result[name].element_size()
+                total_payload_bytes += payload_bytes
+                passthrough_payload_bytes += payload_bytes
+                if t.is_floating_point():
+                    passthrough_weights += int(t.numel())
+                continue
+
+            if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS):
+                result[name] = t.float()
+                meta[name] = "passthrough_ctrl"
+                payload_bytes = result[name].numel() * result[name].element_size()
+                total_payload_bytes += payload_bytes
+                passthrough_payload_bytes += payload_bytes
+                passthrough_weights += int(t.numel())
+                continue
+
+            q, s = quantize_float_tensor(t)
+            result[name + ".q"] = q.cpu().contiguous()
+            result[name + ".scale"] = s.cpu().contiguous()
+            meta[name] = {"type": "int8"}
+            payload_bytes = (
+                result[name + ".q"].numel() * result[name + ".q"].element_size()
+                + result[name + ".scale"].numel() * result[name + ".scale"].element_size()
+            )
+            total_payload_bytes += payload_bytes
+            int8_fallback_payload_bytes += payload_bytes
+            int8_fallback_weights += int(t.numel())
+
+        export_stats["summary"] = {
+            "target_tensors": len(self.states),
+            "target_weights": target_weights,
+            "coverage": target_weights / max(total_fp_weights, 1),
+            "target_payload_bytes": target_payload_bytes,
+            "target_bpw": (8.0 * target_payload_bytes) / max(target_weights, 1),
+            "outlier_weights": outlier_weights,
+            "outlier_payload_bytes": outlier_payload_bytes,
+            "int8_fallback_weights": int8_fallback_weights,
+            "int8_fallback_payload_bytes": int8_fallback_payload_bytes,
+            "passthrough_weights": passthrough_weights,
+            "passthrough_payload_bytes": passthrough_payload_bytes,
+            "payload_bytes_before_torchsave": total_payload_bytes,
+            "effective_payload_bpw_all_weights": (8.0 * total_payload_bytes) / max(total_fp_weights, 1),
+            "model_fp_weights": total_fp_weights,
+        }
+        return result, meta, export_stats
+
+
+def dequantize_state_dict_codebook(
+    result: dict[str, Tensor],
+    meta: dict[str, object],
+    template_sd: dict[str, Tensor],
+) -> dict[str, Tensor]:
+    out: dict[str, Tensor] = {}
+    for name, orig in template_sd.items():
+        info = meta.get(name)
+        if info is None:
+            continue
+        if info in ("passthrough", "passthrough_ctrl"):
+            t = result[name]
+            if t.dtype == torch.float16 and orig.dtype in (torch.float32, torch.bfloat16):
+                t = t.to(orig.dtype)
+            out[name] = t
+            continue
+        if not isinstance(info, dict):
+            raise ValueError(f"Unsupported compression metadata for {name}: {info!r}")
+        if info.get("type") == "int8":
+            out[name] = dequantize_int8_rows(
+                result[name + ".q"],
+                result[name + ".scale"],
+                device=torch.device("cpu"),
+                dtype=orig.dtype,
+            )
+            continue
+        if info.get("type") not in ("codebook_e8p_fp16_scale", "codebook_e8p"):
+            raise ValueError(f"Unsupported compression metadata for {name}: {info!r}")
+        shape = tuple(int(x) for x in info["shape"])
+        block_dim = int(info["block_dim"])
+        idx = result[name + ".idx"].to(dtype=torch.int64)
+        if info.get("type") == "codebook_e8p_fp16_scale":
+            scales = result[name + ".scale"].to(dtype=torch.float32).view(-1, 1)
+        else:
+            scales = _dequantize_codebook_scales(result[name + ".scale"], info)
+        fixed_blocks = _decode_e8p_blocks(
+            idx,
+            float(info["lattice_scale"]),
+            device=torch.device("cpu"),
+            dtype=torch.float32,
+        )
+        rotated_blocks = fixed_blocks * scales
+        sign_vec = _hadamard_sign_vector(name, block_dim, device=torch.device("cpu"), dtype=torch.float32)
+        blocks = hadamard_unrotate_blocks(
+            rotated_blocks,
+            sign_vec,
+            enabled=bool(info.get("hadamard", True)),
+        )
+        restored = _unblockify_weight(blocks, shape).to(orig.dtype)
+        if int(info.get("outlier_count", 0)) > 0:
+            flat = restored.reshape(-1)
+            outlier_idx = result[name + ".outlier_idx"].to(dtype=torch.int64)
+            outlier_q = result[name + ".outlier_q"].to(dtype=torch.float32).reshape(-1)
+            outlier_scale = float(result[name + ".outlier_scale"].to(dtype=torch.float32).reshape(-1)[0].item())
+            outlier_val = (outlier_q * outlier_scale).to(dtype=orig.dtype)
+            flat[outlier_idx] = outlier_val
+            restored = flat.view_as(restored)
+        out[name] = restored
+    return out
+
+
+_BSHF_MAGIC = b"BSHF"
+
+
+def _byte_shuffle(data: bytes, stride: int = 2) -> bytes:
+    """Transpose byte stream by stride position for better compression."""
+    if stride <= 1 or len(data) < stride:
+        return data
+    src = np.frombuffer(data, dtype=np.uint8)
+    n = len(src)
+    out = np.empty(n, dtype=np.uint8)
+    dest_off = 0
+    for pos in range(stride):
+        chunk = src[pos::stride]
+        out[dest_off:dest_off + len(chunk)] = chunk
+        dest_off += len(chunk)
+    return _BSHF_MAGIC + bytes([stride]) + out.tobytes()
+
+
+def _byte_unshuffle(data: bytes) -> bytes:
+    """Inverse of _byte_shuffle. Auto-detects BSHF magic header."""
+    if len(data) < 5 or data[:4] != _BSHF_MAGIC:
+        return data
+    stride = data[4]
+    if stride < 2:
+        return data[5:]
+    payload = np.frombuffer(data, dtype=np.uint8, offset=5)
+    n = len(payload)
+    out = np.empty(n, dtype=np.uint8)
+    src_off = 0
+    for pos in range(stride):
+        chunk_len = n // stride + (1 if pos < n % stride else 0)
+        out[pos::stride][:chunk_len] = payload[src_off:src_off + chunk_len]
+        src_off += chunk_len
+    return out.tobytes()
+
+
+def _compress(data: bytes, compressor: str, byte_shuffle: bool = True) -> bytes:
+    if byte_shuffle:
+        data = _byte_shuffle(data)
+    if compressor == "lzma":
+        return lzma.compress(data, preset=6)
+    elif compressor == "brotli":
+        import brotli
+        return brotli.compress(data, quality=11)
+    raise ValueError(f"Unknown compressor: {compressor!r}")
+
+
+def _decompress(data: bytes, compressor: str, byte_shuffle: bool = True) -> bytes:
+    if compressor == "lzma":
+        raw = lzma.decompress(data)
+    elif compressor == "brotli":
+        import brotli
+        raw = brotli.decompress(data)
+    else:
+        raise ValueError(f"Unknown compressor: {compressor!r}")
+    if byte_shuffle:
+        raw = _byte_unshuffle(raw)
+    return raw
+
+
+def _codebook_scale_dtype(bits: int) -> torch.dtype:
+    return torch.uint8 if bits <= 8 else torch.uint16
+
+
+@torch.no_grad()
+def _quantize_codebook_scales(scales: Tensor, bits: int) -> tuple[Tensor, dict[str, object]]:
+    scales_fp32 = scales.to(dtype=torch.float32).clamp_min(1e-12)
+    if bits >= 16:
+        return scales.to(dtype=torch.float16).contiguous(), {
+            "scale_format": "fp16",
+            "scale_bits": 16,
+        }
+
+    levels = (1 << bits) - 1
+    log_scales = torch.log2(scales_fp32)
+    log_min = float(log_scales.min().item())
+    log_max = float(log_scales.max().item())
+    if log_max <= log_min:
+        scale_codes = torch.zeros_like(scales_fp32, dtype=_codebook_scale_dtype(bits))
+        log_step = 0.0
+    else:
+        log_step = (log_max - log_min) / levels
+        scale_codes = torch.round((log_scales - log_min) / log_step).clamp_(0, levels).to(_codebook_scale_dtype(bits))
+    return scale_codes.contiguous().cpu(), {
+        "scale_format": "log",
+        "scale_bits": int(bits),
+        "scale_log_min": log_min,
+        "scale_log_step": float(log_step),
+    }
+
+
+@torch.no_grad()
+def _dequantize_codebook_scales(scale_payload: Tensor, info: dict[str, object]) -> Tensor:
+    scale_format = str(info.get("scale_format", "fp16"))
+    if scale_format == "fp16":
+        return scale_payload.to(dtype=torch.float32).view(-1, 1)
+    if scale_format != "log":
+        raise ValueError(f"Unsupported codebook scale format: {scale_format!r}")
+    log_min = float(info["scale_log_min"])
+    log_step = float(info["scale_log_step"])
+    if log_step == 0.0:
+        return torch.full(
+            (scale_payload.numel(), 1),
+            2.0 ** log_min,
+            dtype=torch.float32,
+        )
+    logs = log_min + scale_payload.to(dtype=torch.float32).reshape(-1, 1) * log_step
+    return torch.pow(2.0, logs)
+
+
+def _entropy_from_counts(counts: np.ndarray) -> float:
+    total = int(counts.sum())
+    if total <= 0:
+        return 0.0
+    probs = counts[counts > 0].astype(np.float64) / float(total)
+    return float(-(probs * np.log2(probs)).sum())
+
+
+def _entropy_from_u8_bytes(data: bytes) -> float:
+    if not data:
+        return 0.0
+    arr = np.frombuffer(data, dtype=np.uint8)
+    counts = np.bincount(arr, minlength=256)
+    return _entropy_from_counts(counts)
+
+
+def _adjacent_corrcoef(arr: np.ndarray) -> float:
+    if arr.size < 2:
+        return 1.0
+    x0 = arr[:-1].astype(np.float64, copy=False)
+    x1 = arr[1:].astype(np.float64, copy=False)
+    x0 = x0 - x0.mean()
+    x1 = x1 - x1.mean()
+    denom = math.sqrt(float(np.dot(x0, x0) * np.dot(x1, x1)))
+    if denom <= 0.0:
+        return 1.0
+    return float(np.dot(x0, x1) / denom)
+
+
+def _codebook_payload_diagnostics(
+    name: str,
+    idx: Tensor,
+    scales: Tensor,
+    scale_payload: Tensor,
+    shape: tuple[int, int],
+    block_dim: int,
+    compressor: str,
+) -> dict[str, float | int | str]:
+    """Entropy/compressibility diagnostics for a tensor's index and scale payloads."""
+    idx_np = idx.contiguous().numpy().reshape(-1)
+    idx_codes = idx_np.astype(np.int64, copy=False)
+    scales_np = scales.contiguous().numpy().reshape(-1).astype(np.float16, copy=False)
+    scale_payload_np = scale_payload.contiguous().numpy().reshape(-1)
+    idx_bytes = idx_np.tobytes()
+    scale_bytes = scale_payload_np.tobytes()
+    combined_bytes = idx_bytes + scale_bytes
+
+    idx_max = int(idx_codes.max()) if idx_codes.size else 0
+    idx_counts = np.bincount(idx_codes, minlength=max(idx_max + 1, 1))
+    idx_entropy_bits = _entropy_from_counts(idx_counts)
+    idx_unique = int(np.count_nonzero(idx_counts))
+
+    scale_words = scales_np.view(np.uint16)
+    _, scale_counts = np.unique(scale_words, return_counts=True)
+    scale_value_entropy_bits = _entropy_from_counts(scale_counts)
+
+    scale_grid = scales_np.astype(np.float32, copy=False).reshape(shape[0], shape[1] // block_dim)
+    if scale_grid.shape[1] > 1:
+        scale_deltas = np.diff(scale_grid, axis=1).astype(np.float16, copy=False).reshape(-1)
+        scale_delta_words = scale_deltas.view(np.uint16)
+        _, delta_counts = np.unique(scale_delta_words, return_counts=True)
+        scale_delta_entropy_bits = _entropy_from_counts(delta_counts)
+        scale_delta_mean_abs = float(np.mean(np.abs(scale_deltas.astype(np.float32, copy=False))))
+        scale_adjacent_corr = _adjacent_corrcoef(scale_grid.reshape(-1))
+        scale_relative_delta = scale_delta_mean_abs / max(float(np.mean(np.abs(scale_grid))), 1e-12)
+    else:
+        scale_delta_entropy_bits = 0.0
+        scale_delta_mean_abs = 0.0
+        scale_adjacent_corr = 1.0
+        scale_relative_delta = 0.0
+
+    idx_compressed_bytes = len(_compress(idx_bytes, compressor))
+    scale_compressed_bytes = len(_compress(scale_bytes, compressor))
+    combined_compressed_bytes = len(_compress(combined_bytes, compressor))
+    idx_byte_entropy = _entropy_from_u8_bytes(_byte_shuffle(idx_bytes))
+    scale_byte_entropy = _entropy_from_u8_bytes(_byte_shuffle(scale_bytes))
+
+    return {
+        "name": name,
+        "idx_unique": idx_unique,
+        "idx_entropy_bits": idx_entropy_bits,
+        "idx_byte_entropy_bits": idx_byte_entropy,
+        "idx_raw_bytes": len(idx_bytes),
+        "idx_compressed_bytes": idx_compressed_bytes,
+        "idx_entropy_bound_bytes": (idx_np.size * idx_entropy_bits) / 8.0,
+        "scale_unique": int(scale_counts.size),
+        "scale_value_entropy_bits": scale_value_entropy_bits,
+        "scale_delta_entropy_bits": scale_delta_entropy_bits,
+        "scale_byte_entropy_bits": scale_byte_entropy,
+        "scale_raw_bytes": len(scale_bytes),
+        "scale_compressed_bytes": scale_compressed_bytes,
+        "scale_entropy_bound_bytes": (scales_np.size * scale_value_entropy_bits) / 8.0,
+        "scale_delta_mean_abs": scale_delta_mean_abs,
+        "scale_relative_delta": scale_relative_delta,
+        "scale_adjacent_corr": scale_adjacent_corr,
+        "scale_storage_bits": float(8.0 * len(scale_bytes) / max(scale_payload_np.size, 1)),
+        "combined_raw_bytes": len(combined_bytes),
+        "combined_compressed_bytes": combined_compressed_bytes,
+    }
+
+
+def _select_codebook_focus_tensor(
+    diagnostics: list[dict[str, float | int | str]],
+    focus: str,
+) -> dict[str, float | int | str] | None:
+    if not diagnostics:
+        return None
+    if focus:
+        for item in diagnostics:
+            if item["name"] == focus:
+                return item
+        for item in diagnostics:
+            if focus in str(item["name"]):
+                return item
+    return max(diagnostics, key=lambda item: float(item["scale_compressed_bytes"]))
+
+
+def _log_codebook_diagnostics(h: Hyperparameters, quant_stats: dict[str, object]) -> None:
+    diagnostics = list(quant_stats.get("diagnostics", []))
+    if not diagnostics:
+        return
+    topk = max(int(h.codebook_entropy_summary_topk), 0)
+    total_idx_compressed = sum(int(item["idx_compressed_bytes"]) for item in diagnostics)
+    total_scale_compressed = sum(int(item["scale_compressed_bytes"]) for item in diagnostics)
+    total_combined_compressed = sum(int(item["combined_compressed_bytes"]) for item in diagnostics)
+    log(
+        f"codebook:entropy_summary tensors:{len(diagnostics)} "
+        f"idx_compressed_bytes:{total_idx_compressed} "
+        f"scale_compressed_bytes:{total_scale_compressed} "
+        f"combined_compressed_bytes:{total_combined_compressed}"
+    )
+    if topk > 0:
+        ranked = sorted(
+            diagnostics,
+            key=lambda item: (float(item["scale_compressed_bytes"]), float(item["combined_compressed_bytes"])),
+            reverse=True,
+        )[:topk]
+        for item in ranked:
+            log(
+                "codebook:tensor "
+                f"name:{item['name']} "
+                f"idx_raw:{item['idx_raw_bytes']} idx_zip:{item['idx_compressed_bytes']} "
+                f"scale_raw:{item['scale_raw_bytes']} scale_zip:{item['scale_compressed_bytes']} "
+                f"combined_zip:{item['combined_compressed_bytes']} "
+                f"idx_H:{float(item['idx_entropy_bits']):.3f}bits "
+                f"scale_H:{float(item['scale_value_entropy_bits']):.3f}bits "
+                f"scale_dH:{float(item['scale_delta_entropy_bits']):.3f}bits "
+                f"scale_store:{float(item['scale_storage_bits']):.1f}bits "
+                f"scale_rel_delta:{float(item['scale_relative_delta']):.5f} "
+                f"scale_corr:{float(item['scale_adjacent_corr']):.5f}"
+            )
+    focus_item = _select_codebook_focus_tensor(diagnostics, h.codebook_entropy_focus)
+    if focus_item is None:
+        return
+    log(
+        "codebook:focus "
+        f"name:{focus_item['name']} "
+        f"idx_unique:{focus_item['idx_unique']} "
+        f"idx_H:{float(focus_item['idx_entropy_bits']):.4f}bits "
+        f"idx_byte_H:{float(focus_item['idx_byte_entropy_bits']):.4f}bits "
+        f"idx_raw:{focus_item['idx_raw_bytes']} "
+        f"idx_entropy_bound:{float(focus_item['idx_entropy_bound_bytes']):.1f} "
+        f"idx_zip:{focus_item['idx_compressed_bytes']}"
+    )
+    log(
+        "codebook:focus "
+        f"name:{focus_item['name']} "
+        f"scale_unique:{focus_item['scale_unique']} "
+        f"scale_H:{float(focus_item['scale_value_entropy_bits']):.4f}bits "
+        f"scale_dH:{float(focus_item['scale_delta_entropy_bits']):.4f}bits "
+        f"scale_byte_H:{float(focus_item['scale_byte_entropy_bits']):.4f}bits "
+        f"scale_store:{float(focus_item['scale_storage_bits']):.1f}bits "
+        f"scale_raw:{focus_item['scale_raw_bytes']} "
+        f"scale_entropy_bound:{float(focus_item['scale_entropy_bound_bytes']):.1f} "
+        f"scale_zip:{focus_item['scale_compressed_bytes']}"
+    )
+    log(
+        "codebook:focus "
+        f"name:{focus_item['name']} "
+        f"combined_zip:{focus_item['combined_compressed_bytes']} "
+        f"scale_delta_mean_abs:{float(focus_item['scale_delta_mean_abs']):.6f} "
+        f"scale_rel_delta:{float(focus_item['scale_relative_delta']):.6f} "
+        f"scale_adjacent_corr:{float(focus_item['scale_adjacent_corr']):.6f}"
+    )
+
+
+def serialize(h: Hyperparameters, base_model: torch.nn.Module, code: str) -> int:
+    quantizer = CodebookQuantizer(h, base_model)
+    device = next(base_model.parameters()).device
+    code_bytes = len(code.encode("utf-8"))
+    bytes_total = code_bytes
+    if h.is_main_process:
+        torch.save(base_model.state_dict(), h.model_path)
+        model_bytes = os.path.getsize(h.model_path)
+        log(f"Serialized model: {model_bytes} bytes")
+        log(f"Code size: {code_bytes} bytes")
+
+    hessians: dict[str, Tensor] = {}
+    if quantizer.target_names:
+        if h.is_main_process:
+            log("codebook:collecting calibration hessians...")
+        t0 = time.perf_counter()
+        calib_loader = DistributedTokenLoader(h.train_files, h.rank, h.world_size, device)
+        hessians = collect_hessians(
+            base_model,
+            calib_loader,
+            h,
+            device,
+            quantizer.target_names,
+            n_calibration_batches=h.codebook_calibration_batches,
+        )
+        if h.is_main_process:
+            log(f"codebook:collected {len(hessians)} hessians in {time.perf_counter() - t0:.1f}s")
+
+    if h.is_main_process:
+        t0 = time.perf_counter()
+        quantizer.fit(base_model, hessians)
+        log(f"codebook:fit_time:{time.perf_counter() - t0:.1f}s")
+        quant_result, quant_meta, quant_stats = quantizer.build_export(base_model.state_dict())
+        quant_buf = io.BytesIO()
+        torch.save({"w": quant_result, "m": quant_meta, "s": quant_stats}, quant_buf)
+        quant_raw = quant_buf.getvalue()
+        quant_blob = _compress(quant_raw, h.compressor)
+        quant_file_bytes = len(quant_blob)
+        bytes_total = quant_file_bytes + code_bytes
+        with open(h.quantized_model_path, "wb") as f:
+            f.write(quant_blob)
+        summary = quant_stats["summary"]
+        log(
+            f"Serialized model codebook+{h.compressor}: {quant_file_bytes} bytes "
+            f"(payload_before_torchsave:{summary['payload_bytes_before_torchsave']} bytes)"
+        )
+        log(
+            f"Codebook coverage:{summary['coverage']:.4%} target_bpw:{summary['target_bpw']:.4f} "
+            f"effective_payload_bpw_all_weights:{summary['effective_payload_bpw_all_weights']:.4f}"
+        )
+        log(
+            f"Codebook scale_codec:{'fp16' if h.codebook_scale_bits >= 16 else 'log'} "
+            f"scale_bits:{h.codebook_scale_bits} "
+            f"entropy_weight:{h.codebook_assignment_entropy_weight:.4f}"
+        )
+        log(
+            f"Codebook outliers frac:{h.codebook_outlier_frac:.6f} "
+            f"max_count:{h.codebook_outlier_max_count} "
+            f"outlier_weights:{summary['outlier_weights']} "
+            f"outlier_bytes:{summary['outlier_payload_bytes']}"
+        )
+        log(
+            f"Fallback payloads int8_weights:{summary['int8_fallback_weights']} "
+            f"int8_bytes:{summary['int8_fallback_payload_bytes']} "
+            f"passthrough_weights:{summary['passthrough_weights']} "
+            f"passthrough_bytes:{summary['passthrough_payload_bytes']}"
+        )
+        _log_codebook_diagnostics(h, quant_stats)
+        log(f"Total submission size 
codebook+{h.compressor}: {bytes_total} bytes") + return bytes_total + + +def deserialize(h: Hyperparameters, device: torch.device) -> GPT: + eval_model = GPT(h).to(device).bfloat16() + restore_fp32_params(eval_model) + + template_sd = eval_model.state_dict() + + with open(h.quantized_model_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(_decompress(quant_blob_disk, h.compressor)), + map_location="cpu", + ) + deq_state = dequantize_state_dict_codebook(quant_state["w"], quant_state["m"], template_sd) + eval_model.load_state_dict(deq_state, strict=True) + + return eval_model + + +# ---------------------------------------- +# Evaluation +# ---------------------------------------- + +def _loss_bpb(loss_sum, token_count, byte_count) -> tuple[float, float]: + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + return val_loss, val_bpb + + +def eval_val( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + model: nn.Module +) -> tuple[float, float]: + seq_len = h.eval_seq_len + local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, " + f"GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * h.rank) // h.world_size + seq_end = (total_seqs * (h.rank + 1)) // h.world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + # Avoid leaving inference-mode tensors in module caches that training later reuses. 
+ with torch.no_grad(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_data.val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = val_data.base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (val_data.has_leading_space_lut[tgt_ids] & ~val_data.is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + model.train() + return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) + + +def eval_val_sliding( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + base_model: nn.Module, + batch_seqs: int = 32 +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + base_model.eval() + logits_fn = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + seq_len = h.eval_seq_len + context_size = seq_len - h.eval_stride + total_tokens = val_data.val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, h.eval_stride) + if ws + context_size < total_tokens] + + total_windows = len(window_starts) + my_s = (total_windows * h.rank) // h.world_size + my_e = (total_windows * 
(h.rank + 1)) // h.world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Avoid leaving inference-mode tensors in module caches that training later reuses. + with torch.no_grad(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + we = min(ws + seq_len, total_tokens) + wlen = we - ws + wlens.append(wlen) + chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = logits_fn(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else context_size + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + base_model.train() + return _loss_bpb(loss_sum, token_count, byte_count) + + +def timed_eval(label: str, fn, *args, **kwargs) -> tuple[float, float]: + 
torch.cuda.synchronize() + t0 = time.perf_counter() + val_loss, val_bpb = fn(*args, **kwargs) + torch.cuda.synchronize() + elapsed_ms = 1000.0 * (time.perf_counter() - t0) + log(f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms") + return val_loss, val_bpb + + + def run_evals( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + eval_model: torch.nn.Module + ) -> None: + compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True) + timed_eval("final_codebook_roundtrip", eval_val, h, device, val_data, compiled_model) + if h.sliding_window_enabled: + timed_eval("final_codebook_sliding_window", eval_val_sliding, h, device, val_data, eval_model) + +# ----------------------------- +# Training +# ----------------------------- + +def train_model(h: Hyperparameters, device: torch.device, val_data: ValidationData) -> tuple[GPT, torch.nn.Module]: + # Set up model + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + if h.distributed: + model = DDP(compiled_model, device_ids=[h.local_rank], broadcast_buffers=False) + else: + model = compiled_model + log(f"model_params:{sum(p.numel() for p in base_model.parameters())}") + + # Set up optimizer and load train data + optimizers = Optimizers(h, base_model) + train_loader = DistributedTokenLoader(h.train_files, h.rank, h.world_size, device) + soft_snap = FixedCodebookSoftSnap(h, base_model) if h.codebook_soft_snap_enabled else None + codebook_penalty = FixedCodebookPenalty(h, base_model) if h.codebook_penalty_enabled else None + + # Helper functions for training + max_wallclock_ms = 1000.0 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + if max_wallclock_ms is not None and h.codebook_reserve_seconds > 0: + max_wallclock_ms = max(max_wallclock_ms - h.codebook_reserve_seconds * 1000.0, 0.0) + log(f"codebook:reserving {h.codebook_reserve_seconds:.0f}s, 
effective={max_wallclock_ms:.0f}ms") + + def training_frac(step: int, elapsed_ms: float) -> float: + """Fraction of training completed (0 to 1), using step or wallclock.""" + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-9) + + def lr_mul(frac: float) -> float: + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + def step_fn(step, lr_scale, frac): + optimizers.zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + if h.distributed: + model.require_backward_grad_sync = micro_step == h.grad_accum_steps - 1 + x, y = train_loader.next_batch(h.train_batch_tokens, h.train_seq_len, h.grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + + momentum_frac = min(step / h.muon_momentum_warmup_steps, 1.0) if h.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - momentum_frac) * h.muon_momentum_warmup_start + momentum_frac * h.muon_momentum + for group in optimizers.optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * lr_scale + + codebook_penalty_loss = None + if codebook_penalty is not None: + codebook_penalty_loss = codebook_penalty.penalty(step, frac) + if codebook_penalty_loss is not None: + codebook_penalty_loss.backward() + + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm) + + optimizers.step() + if soft_snap is not None: + soft_snap.apply(step + 1) + return train_loss, codebook_penalty_loss + + # Model warmup + if h.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in 
base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0, 0.0) + if warmup_step <= 5 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == h.warmup_steps: + log(f"warmup_step: {warmup_step + 1}/{h.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + optimizers.zero_grad_all() + if h.distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader( + h.train_files, h.rank, h.world_size, device) + + # Training loop + use_ema = h.ema_decay < 1.0 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} if use_ema else None + ema_decay = h.ema_decay + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == h.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (h.val_loss_every > 0 and step % h.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(h, device, val_data, model) + log(f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}") + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < h.iterations: + log( + f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms " + f"step: {step}/{h.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + if soft_snap is not None and soft_snap.should_activate(frac): + soft_snap.activate(step, frac) + if codebook_penalty is not None and 
codebook_penalty.should_activate(frac): + codebook_penalty.activate(step, frac) + scale = lr_mul(frac) + train_loss, codebook_penalty_loss = step_fn(step, scale, frac) + + if use_ema: + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + should_log_train = ( + h.train_log_every > 0 + and (step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1000.0) + train_loss_value = train_loss.item() + total_loss_value = train_loss_value + codebook_penalty_suffix = "" + if codebook_penalty_loss is not None: + codebook_penalty_loss_value = codebook_penalty_loss.item() + total_loss_value += codebook_penalty_loss_value + codebook_penalty_suffix = f" codebook_penalty_loss: {codebook_penalty_loss_value:.5f}" + log( + f"{step}/{h.iterations} train_loss: {train_loss_value:.4f} total_loss: {total_loss_value:.4f} " + f"train_time: {approx_training_time_ms / 60000:.1f}m tok/s: {tok_per_sec:.0f}" + f"{codebook_penalty_suffix}" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if h.distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + if use_ema: + log("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + 
base_model.load_state_dict(avg_state, strict=True) + + return base_model, compiled_model + + +def train_and_eval(h: Hyperparameters, device: torch.device) -> None: + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + + val_data = ValidationData(h, device) + log(f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}") + log(f"val_tokens: {val_data.val_tokens.numel() - 1}") + + base_model, compiled_model = train_model(h, device, val_data) + timed_eval("pre-codebook post-ema" if h.ema_decay < 1.0 else "pre-codebook", eval_val, h, device, val_data, compiled_model) + + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + eval_model = deserialize(h, device) + + run_evals(h, device, val_data, eval_model) + + +def main(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + 
os.makedirs("logs", exist_ok=True) + log(100 * "=", console=False) + log("Hyperparameters:", console=True) + for k, v in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}", console=True) + log(Path(__file__).read_text(encoding="utf-8"), console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running PyTorch {torch.__version__}", console=False) + log( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log("=" * 100, console=False) + + train_and_eval(h, device) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Tue Apr 7 05:50:03 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 47C P0 129W / 700W | 2049MiB / 81559MiB | 1% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 39C P0 123W / 700W | 2049MiB / 81559MiB | 5% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1541MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 43C P0 123W / 700W | 1521MiB / 81559MiB | 4% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 46C P0 125W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 37C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 45C P0 126W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 36C P0 120W / 700W | 1521MiB / 81559MiB | 1% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + 
++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +train_shards: 80 +val_tokens: 44847104 +model_params:40173675 +codebook:reserving 30s, effective=570000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +0/20000 val_loss: 8.3149 val_bpb: 3.5686 +1/20000 train_loss: 8.3164 total_loss: 8.3164 train_time: 0.0m tok/s: 7181548 +2/20000 train_loss: 12.2719 total_loss: 12.2719 train_time: 0.0m tok/s: 7083493 +3/20000 train_loss: 10.8179 total_loss: 10.8179 train_time: 0.0m tok/s: 7012024 +4/20000 train_loss: 9.0448 total_loss: 9.0448 train_time: 0.0m tok/s: 6985747 +5/20000 train_loss: 7.7871 total_loss: 7.7871 train_time: 0.0m tok/s: 6963903 +500/20000 train_loss: 3.0489 total_loss: 3.0489 train_time: 1.0m tok/s: 6778298 +1000/20000 train_loss: 2.8963 total_loss: 2.8963 train_time: 1.9m tok/s: 6771907 +1500/20000 train_loss: 2.8779 total_loss: 2.8779 train_time: 2.9m tok/s: 6776683 +2000/20000 train_loss: 2.7783 total_loss: 2.7783 train_time: 3.9m tok/s: 6777486 +2500/20000 train_loss: 2.8096 total_loss: 2.8096 train_time: 4.8m tok/s: 6778357 +3000/20000 train_loss: 2.6859 total_loss: 2.6859 train_time: 5.8m tok/s: 6779349 +3500/20000 train_loss: 2.6651 total_loss: 2.6651 train_time: 6.8m tok/s: 6780677 +4000/20000 train_loss: 2.6484 total_loss: 2.6484 train_time: 7.7m tok/s: 6781509 +4000/20000 val_loss: 2.6407 val_bpb: 1.1334 +codebook_penalty:enabled step:4425/20000 start_frac:0.900 frac:0.900 weight:4.00000 
mode:fast_direction_target refresh_every:64 target_tensors:78 target_weights:37486592 +4500/20000 train_loss: 2.6035 total_loss: 2.7927 train_time: 8.8m tok/s: 6720393 codebook_penalty_loss: 0.18923 +4822/20000 val_loss: 2.5760 val_bpb: 1.1056 +stopping_early: wallclock_cap train_time: 570074ms step: 4822/20000 +peak memory allocated: 30446 MiB reserved: 31150 MiB +ema:applying EMA weights +pre-codebook post-ema val_loss:2.57351430 val_bpb:1.10450966 eval_time:1982ms +Serialized model: 155502379 bytes +Code size: 121768 bytes +codebook:collecting calibration hessians... +codebook:collected 78 hessians in 9.7s +codebook:fit target_tensors:78 target_weights:37486592 coverage:93.3113% rel_mse:3.521538e-02 target_bpw:4.0000 +codebook:fit_time:17.9s +Serialized model codebook+brotli: 15742182 bytes (payload_before_torchsave:17625458 bytes) +Codebook coverage:93.3113% target_bpw:3.1705 effective_payload_bpw_all_weights:3.5099 +Codebook scale_codec:log scale_bits:8 entropy_weight:0.0000 +Codebook outliers frac:0.000000 max_count:2048 outlier_weights:159744 outlier_bytes:798876 +Fallback payloads int8_weights:2621440 int8_bytes:2637824 passthrough_weights:65643 passthrough_bytes:131286 +codebook:entropy_summary tensors:78 idx_compressed_bytes:9371611 scale_compressed_bytes:3584882 combined_compressed_bytes:12945121 +codebook:tensor name:blocks.6.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:108418 combined_zip:370214 idx_H:15.360bits scale_H:10.873bits scale_dH:11.780bits scale_store:8.0bits scale_rel_delta:0.34297 scale_corr:0.00064 +codebook:tensor name:blocks.3.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:108110 combined_zip:369928 idx_H:15.438bits scale_H:10.772bits scale_dH:11.659bits scale_store:8.0bits scale_rel_delta:0.32160 scale_corr:-0.00469 +codebook:tensor name:blocks.1.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:107801 combined_zip:369600 idx_H:15.435bits scale_H:10.746bits 
scale_dH:11.634bits scale_store:8.0bits scale_rel_delta:0.31628 scale_corr:-0.00561 +codebook:tensor name:blocks.2.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:107093 combined_zip:368922 idx_H:15.437bits scale_H:10.768bits scale_dH:11.658bits scale_store:8.0bits scale_rel_delta:0.32119 scale_corr:-0.00747 +codebook:tensor name:blocks.11.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:107067 combined_zip:369017 idx_H:15.463bits scale_H:10.759bits scale_dH:11.653bits scale_store:8.0bits scale_rel_delta:0.31857 scale_corr:0.00201 +codebook:tensor name:blocks.0.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:106771 combined_zip:368690 idx_H:15.460bits scale_H:10.752bits scale_dH:11.636bits scale_store:8.0bits scale_rel_delta:0.31739 scale_corr:-0.00359 +codebook:tensor name:blocks.8.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:106346 combined_zip:368253 idx_H:15.428bits scale_H:10.820bits scale_dH:11.716bits scale_store:8.0bits scale_rel_delta:0.33150 scale_corr:-0.00060 +codebook:tensor name:blocks.5.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:106284 combined_zip:368089 idx_H:15.403bits scale_H:10.830bits scale_dH:11.733bits scale_store:8.0bits scale_rel_delta:0.33326 scale_corr:-0.00389 +codebook:tensor name:blocks.9.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:106152 combined_zip:368089 idx_H:15.449bits scale_H:10.802bits scale_dH:11.704bits scale_store:8.0bits scale_rel_delta:0.32790 scale_corr:0.00064 +codebook:tensor name:blocks.4.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:106071 combined_zip:367872 idx_H:15.415bits scale_H:10.801bits scale_dH:11.696bits scale_store:8.0bits scale_rel_delta:0.32807 scale_corr:-0.00460 +codebook:tensor name:blocks.7.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:105352 combined_zip:367190 idx_H:15.382bits 
scale_H:10.846bits scale_dH:11.753bits scale_store:8.0bits scale_rel_delta:0.33804 scale_corr:-0.00391 +codebook:tensor name:blocks.10.mlp.proj.weight idx_raw:262144 idx_zip:262154 scale_raw:131072 scale_zip:105299 combined_zip:367268 idx_H:15.459bits scale_H:10.784bits scale_dH:11.683bits scale_store:8.0bits scale_rel_delta:0.32552 scale_corr:-0.00281 +codebook:focus name:blocks.6.mlp.proj.weight idx_unique:51231 idx_H:15.3595bits idx_byte_H:7.9804bits idx_raw:262144 idx_entropy_bound:251650.5 idx_zip:262154 +codebook:focus name:blocks.6.mlp.proj.weight scale_unique:3247 scale_H:10.8732bits scale_dH:11.7796bits scale_byte_H:6.5936bits scale_store:8.0bits scale_raw:131072 scale_entropy_bound:178146.6 scale_zip:108418 +codebook:focus name:blocks.6.mlp.proj.weight combined_zip:370214 scale_delta_mean_abs:0.017422 scale_rel_delta:0.342973 scale_adjacent_corr:0.000635 +Total submission size codebook+brotli: 15863950 bytes +final_codebook_roundtrip val_loss:2.86139953 val_bpb:1.22806523 eval_time:20566ms +final_codebook_sliding_window val_loss:2.82183456 val_bpb:1.21108460 eval_time:98458ms
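
The `val_bpb` figures in the log above are a direct function of the mean token NLL and the tokens-to-bytes ratio of the validation set, per `_loss_bpb` in the script. A minimal sketch of that conversion with plain floats (the tensor plumbing stripped out; `loss_to_bpb` is a hypothetical name for illustration):

```python
import math

def loss_to_bpb(loss_sum: float, token_count: float, byte_count: float) -> tuple[float, float]:
    # Mean cross-entropy in nats per token.
    val_loss = loss_sum / token_count
    # nats/token -> bits/token (divide by ln 2), then bits/token -> bits/byte
    # via the validation set's tokens-per-byte ratio.
    val_bpb = val_loss / math.log(2.0) * (token_count / byte_count)
    return val_loss, val_bpb

# Sanity check: ln(2) nats/token over 2 bytes/token is exactly 0.5 bits/byte.
loss, bpb = loss_to_bpb(math.log(2.0) * 1000, 1000, 2000)
```

At the logged sliding-window loss of ~2.822 nats/token (~4.07 bits/token), a bpb of ~1.211 implies the validation text averages roughly 3.36 bytes per token.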