diff --git a/pg_novel_ideas.md b/pg_novel_ideas.md
index 8d2eb55b74..5c735e1f34 100644
--- a/pg_novel_ideas.md
+++ b/pg_novel_ideas.md
@@ -914,6 +914,20 @@
 ### Complementary Training + EngramLite + BackoffNgramMixer: The Integrated Stack
 
+**Status: IMPLEMENTED** in `train_gpt_mlx_kl.py` (April 2026).
+
+**Env vars for full moonshot run (8×H100):**
+```
+ENGRAM_LITE_ENABLED=1 COMPLEMENT_ALPHA=0.5 NGRAM_MIXER_ENABLED=1 NGRAM_ALPHA=0.25 NGRAM_MAX_ORDER=4
+```
+
+**Smoke test (M1, 100 steps):**
+```
+RUN_ID=moonshot_test ITERATIONS=100 TRAIN_BATCH_TOKENS=8192 VAL_LOSS_EVERY=0 VAL_BATCH_SIZE=8192 \
+WARMUP_STEPS=3 ENGRAM_LITE_ENABLED=1 COMPLEMENT_ALPHA=0.5 NGRAM_MIXER_ENABLED=1 EVAL_MODE=standard \
+python3 train_gpt_mlx_kl.py
+```
+
 **Why this combination is the single best bet nobody has fully tried:**
 
 The top competition results reveal three independent discoveries that, when properly integrated, form a system greater than the sum of its parts:
diff --git a/train_gpt_mlx_kl.py b/train_gpt_mlx_kl.py
index 9cf931ef01..ea63cc4a8d 100644
--- a/train_gpt_mlx_kl.py
+++ b/train_gpt_mlx_kl.py
@@ -110,6 +110,17 @@ class Hyperparameters:
     use_gptq_lite: bool = bool(int(os.environ.get("USE_GPTQ_LITE", "1")))
     grad_clip_norm: float = float(os.environ.get("GRAD_CLIP_NORM", 0.0))
 
+    # Moonshot stack: EngramLite + SkipGram + Complementary Training + BackoffNgramMixer
+    engram_lite_enabled: bool = bool(int(os.environ.get("ENGRAM_LITE_ENABLED", "0")))
+    engram_hash_size: int = int(os.environ.get("ENGRAM_HASH_SIZE", 2048))
+    engram_embed_dim: int = int(os.environ.get("ENGRAM_EMBED_DIM", 128))
+    engram_n_heads: int = int(os.environ.get("ENGRAM_N_HEADS", 2))
+    skipgram_hash_size: int = int(os.environ.get("SKIPGRAM_HASH_SIZE", 0))
+    complement_alpha: float = float(os.environ.get("COMPLEMENT_ALPHA", 0.0))
+    ngram_mixer_enabled: bool = bool(int(os.environ.get("NGRAM_MIXER_ENABLED", "0")))
+    ngram_alpha: float = float(os.environ.get("NGRAM_ALPHA", 0.25))
+    ngram_max_order: int = int(os.environ.get("NGRAM_MAX_ORDER", 4))
+
     out_dir: str = os.environ.get("OUT_DIR", "logs")
 
     @property
@@ -140,6 +151,7 @@ def lr_mul(self, step: int, elapsed_ms: float) -> float:
     "attn_scale", "attn_scales", "mlp_scale", "mlp_scales",
     "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights",
     "smear",  # SmearGate gate parameter → scalar optimizer
+    "gate_bias",  # EngramLiteEmbedding learned gate → scalar optimizer
 )
 
 # ============================================================================
@@ -224,6 +236,34 @@ def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.array]:
         y = chunk[1:].reshape(-1, seq_len)
         return mx.array(x, dtype=mx.int32), mx.array(y, dtype=mx.int32)
 
+def build_bigram_stats(data_path: str, vocab_size: int, max_tokens: int = 5_000_000) -> np.ndarray:
+    """Pre-compute bigram transition probabilities P(next|prev) from training data.
+
+    Reads up to *max_tokens* tokens from the first available training shards.
+    Returns a (vocab_size, vocab_size) float32 matrix normalised per row.
+    Uses np.bincount for speed; Laplace-smoothed so every entry is > 0.
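+
+    Shard layout is assumed here (llm.c convention): a 256×int32 header with
+    the token count at header[2], followed by uint16 token ids.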
+ """ + train_files = sorted(glob.glob(f"{data_path}/fineweb_train_*.bin")) + if not train_files: + raise FileNotFoundError(f"No training files in {data_path}") + counts = np.zeros(vocab_size * vocab_size, dtype=np.float32) + tokens_read = 0 + for path in train_files: + if tokens_read >= max_tokens: + break + header = np.fromfile(path, dtype=" mx.array: pad = mx.zeros((tokens.shape[0], 1, bigram_emb.shape[-1]), dtype=bigram_emb.dtype) return mx.concatenate([pad, bigram_emb], axis=1) # (B, T, dim) +# ============================================================================ +# MOONSHOT INNOVATION: EngramLiteEmbedding +# Gated multi-head bigram + trigram hash logit bias. Inspired by #1089. +# Replaces BigramHash when ENGRAM_LITE_ENABLED=1. +# Key fix over raw TrigramHash (#609): per-order gate starts suppressed +# (gate_bias=-2 → sigmoid≈0.12) so the model learns when to trust hashes. +# ============================================================================ +class EngramLiteEmbedding(nn.Module): + """Multi-head hashed bigram+trigram embeddings with learned gating. + + For each n-gram order (2=bigram, 3=trigram): + - K hash functions map context to table indices + - Mean-pool across K heads (reduces collision noise) + - Sum contributions from both orders + A single learned scalar gate (sigmoid, initialised suppressed) scales + the overall output before projection to vocab_size. + + Parameter budget (default hash_size=2048, embed_dim=128, n_heads=2): + bigram_tables : 2 × 2048×128 = 524K + trigram_tables : 2 × 2048×128 = 524K + proj : 128×V = 131K (V=vocab_size=1024) + gate_bias : scalar ≈ 0 + Total ≈ 1.2M params ≈ 0.9MB in int6+zstd + """ + _PRIMES = [31337, 59999, 73721, 97531] # Distinct primes for independent hash heads + + def __init__(self, hash_size: int = 2048, embed_dim: int = 128, + output_dim: int = 1024, n_heads: int = 2): + super().__init__() + self.hash_size = hash_size + self.embed_dim = embed_dim + self.n_heads = n_heads + self._primes = self._PRIMES[:n_heads] + # Separate tables for bigrams and trigrams; list → tracked by MLX + self.bigram_tables = [nn.Embedding(hash_size, embed_dim) for _ in range(n_heads)] + self.trigram_tables = [nn.Embedding(hash_size, embed_dim) for _ in range(n_heads)] + for t in self.bigram_tables + self.trigram_tables: + t.weight = t.weight * 0.01 + self.proj = nn.Linear(embed_dim, output_dim, bias=False) + # gate_bias: scalar; sigmoid(-2)≈0.12 — output starts nearly suppressed, + # model learns to open it. Matches CONTROL_TENSOR_NAME_PATTERNS so it + # goes to Adam scalar optimizer rather than Muon. 
+        self.gate_bias = mx.full((1,), -2.0, dtype=mx.float32)
+
+    def _bigram_idx(self, tokens: mx.array, prime: int) -> mx.array:
+        t_prev = tokens[:, :-1]
+        t_curr = tokens[:, 1:]
+        return mx.remainder(t_prev * prime + t_curr, self.hash_size)
+
+    def _trigram_idx(self, tokens: mx.array, prime: int) -> mx.array:
+        # Horner's method with intermediate modulo to stay in int32
+        t0 = tokens[:, :-2]
+        t1 = tokens[:, 1:-1]
+        t2 = tokens[:, 2:]
+        h = mx.remainder(t0 * prime + t1, self.hash_size)
+        return mx.remainder(h * prime + t2, self.hash_size)
+
+    def __call__(self, tokens: mx.array) -> mx.array:
+        """tokens: (B, T) int32 → logit bias: (B, T, output_dim)"""
+        B, T = tokens.shape
+        accum = mx.zeros((B, T, self.embed_dim), dtype=mx.bfloat16)
+
+        # Bigram (valid from position 1)
+        head_embs = [self.bigram_tables[hi](self._bigram_idx(tokens, self._primes[hi]))
+                     for hi in range(self.n_heads)]
+        bigram_emb = sum(head_embs) / self.n_heads  # (B, T-1, E)
+        pad1 = mx.zeros((B, 1, self.embed_dim), dtype=bigram_emb.dtype)
+        accum = accum + mx.concatenate([pad1, bigram_emb], axis=1)
+
+        # Trigram (valid from position 2)
+        if T > 2:
+            head_embs = [self.trigram_tables[hi](self._trigram_idx(tokens, self._primes[hi]))
+                         for hi in range(self.n_heads)]
+            trigram_emb = sum(head_embs) / self.n_heads  # (B, T-2, E)
+            pad2 = mx.zeros((B, 2, self.embed_dim), dtype=trigram_emb.dtype)
+            accum = accum + mx.concatenate([pad2, trigram_emb], axis=1)
+
+        gate = mx.sigmoid(self.gate_bias).astype(accum.dtype)
+        return self.proj(accum) * gate  # (B, T, output_dim)
+
+# ============================================================================
+# MOONSHOT INNOVATION: SkipGramHashEmbedding
+# Hash embedding using non-adjacent token pairs as context.
+# Captures "template" patterns where intervening tokens vary (HTML, code, etc.)
+# Enabled when SKIPGRAM_HASH_SIZE > 0.
+# ============================================================================
+class SkipGramHashEmbedding(nn.Module):
+    """Logit bias from hashed non-contiguous token pairs.
+
+    Each skip pattern (-a, -b) hashes tokens[t-a] and tokens[t-b] together
+    to produce an additive logit correction at position t.
+    Complements BigramHash / EngramLite: a gapped pair is a joint feature
+    that no contiguous n-gram hash represents, even where individual
+    positions overlap.
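+
+    Usage sketch (illustrative; default sizes):
+        sg = SkipGramHashEmbedding(hash_size=4096, output_dim=1024)
+        bias = sg(tokens)       # tokens: (B, T) int32
+        logits = logits + bias  # additive logit correction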
+ """ + _DEFAULT_PATTERNS = [(-1, -3), (-1, -5), (-2, -4)] + + def __init__(self, hash_size: int = 4096, output_dim: int = 1024, + skip_patterns: list | None = None): + super().__init__() + self.hash_size = hash_size + self.output_dim = output_dim + self.skip_patterns = skip_patterns or self._DEFAULT_PATTERNS + # List of embedding tables — one per pattern; list tracked by MLX + self.tables = [nn.Embedding(hash_size, output_dim) for _ in self.skip_patterns] + for t in self.tables: + t.weight = t.weight * 0.01 + + def __call__(self, tokens: mx.array) -> mx.array: + """tokens: (B, T) int32 → logit bias: (B, T, output_dim)""" + B, T = tokens.shape + output = mx.zeros((B, T, self.output_dim), dtype=mx.bfloat16) + prime = 31337 + for pi, pattern in enumerate(self.skip_patterns): + valid_start = abs(min(pattern)) # positions 0..valid_start-1 are zero-padded + if T <= valid_start: + continue + # Build hash from all offsets in the pattern + h = mx.zeros((B, T - valid_start), dtype=mx.int32) + for offset in pattern: + s = valid_start + offset # ≥ 0 by construction + tok_slice = tokens[:, s: s + (T - valid_start)] + h = mx.remainder(h * prime + tok_slice, self.hash_size) + emb = self.tables[pi](h) # (B, T-valid_start, output_dim) + pad = mx.zeros((B, valid_start, self.output_dim), dtype=emb.dtype) + output = output + mx.concatenate([pad, emb], axis=1) + return output + +# ============================================================================ +# MOONSHOT INNOVATION: BackoffNgramMixer +# Causal, fully-normalized n-gram language model built at eval time. +# Zero artifact cost (no weights stored — built from already-scored tokens). +# Correct per competition rules: full-vocab normalized, backward-looking only. +# Enabled when NGRAM_MIXER_ENABLED=1. +# ============================================================================ +class BackoffNgramMixer: + """Causal BackoffNgram model with linear interpolation. + + For each validation token at position t: + 1. Query all n-gram orders using tokens[t-order .. t-1] as context. + 2. Linearly interpolate from unigram (order=1) up to max_order. + weight(order) = count / (count + 5) — simple additive discount. + 3. Mix with neural distribution: P = (1-alpha)*P_neural + alpha*P_ngram. + 4. Score: NLL = -log P[true_token]. + 5. Update count tables with true token (strictly causal). 
+ + alpha_mode='fixed' → constant NGRAM_ALPHA mixing weight + alpha_mode='entropy' → higher alpha when neural model is uncertain + (alpha range [0.15, 0.60] linearly mapped from entropy) + """ + def __init__(self, vocab_size: int = 1024, max_order: int = 4, + alpha: float = 0.25, alpha_mode: str = "fixed"): + from collections import defaultdict + self.vocab_size = vocab_size + self.max_order = max_order + self.alpha = alpha + self.alpha_mode = alpha_mode + # counts[order][ctx_hash] = np.ndarray of shape (vocab_size,) + # total[order][ctx_hash] = float + self.counts = [defaultdict(lambda: np.zeros(vocab_size, dtype=np.float32)) + for _ in range(max_order + 1)] + self.totals = [defaultdict(float) for _ in range(max_order + 1)] + + def _hash_ctx(self, ctx: list[int]) -> int: + h = 0 + for t in ctx: + h = (h * 31337 + int(t)) & 0x7FFFFFFF + return h + + def _ngram_probs(self, context: list[int]) -> np.ndarray: + """Compute interpolated n-gram probability distribution over vocab.""" + probs = np.ones(self.vocab_size, dtype=np.float64) / self.vocab_size + for order in range(1, self.max_order + 1): + if len(context) < order: + break + ctx = context[-order:] + h = self._hash_ctx(ctx) + cnt = self.counts[order][h] + tot = self.totals[order][h] + if tot > 0: + lam = tot / (tot + 5.0) + order_probs = (cnt + 1e-10) / (tot + 1e-10 * self.vocab_size) + order_probs = order_probs / order_probs.sum() + probs = (1.0 - lam) * probs + lam * order_probs + return (probs / probs.sum()).astype(np.float64) + + def _compute_alpha(self, neural_logits: np.ndarray) -> float: + if self.alpha_mode == "fixed": + return self.alpha + lmax = neural_logits.max() + p = np.exp(neural_logits - lmax) + p /= p.sum() + ent = -np.sum(p * np.log2(p + 1e-12)) + max_ent = math.log2(self.vocab_size) + return 0.15 + 0.45 * (ent / max_ent) + + def score_and_update(self, context: list[int], true_token: int, + neural_logits: np.ndarray) -> float: + """Score *true_token* under the mixed distribution and update caches. + + Must be called in causal order (position 0, 1, 2, …). + Returns the negative log-probability (NLL) contribution. + """ + ngram_p = self._ngram_probs(context) + alpha = self._compute_alpha(neural_logits) + # Neural model: convert raw logits → probabilities + lmax = neural_logits.max() + neural_p = np.exp(neural_logits - lmax) + neural_p /= neural_p.sum() + mixed = (1.0 - alpha) * neural_p + alpha * ngram_p + mixed /= mixed.sum() + nll = -math.log(float(mixed[true_token]) + 1e-30) + # Update caches AFTER scoring (strict causality) + for order in range(1, self.max_order + 1): + if len(context) >= order: + ctx = context[-order:] + h = self._hash_ctx(ctx) + self.counts[order][h][true_token] += 1.0 + self.totals[order][h] += 1.0 + return nll + # ============================================================================ # INNOVATION: SmearGate — blend each token embedding with previous token's # Technique: @unnir (PR #102/#135). Gate initialized to 3.0 → sigmoid≈0.95 pass-through. 
@@ -414,7 +673,10 @@ def __init__(self, vocab_size, num_layers, dim, num_heads, num_kv_heads, mlp_mult, logit_chunk_tokens, logit_softcap, rope_base, tied_embed_init_std, qk_gain_init, bigram_hash_size, use_ortho_init, rope_dims: int = 0, xsa_last_n: int = 0, - use_ln_scale: bool = True, smear_enabled: bool = True): + use_ln_scale: bool = True, smear_enabled: bool = True, + engram_lite_enabled: bool = False, engram_hash_size: int = 2048, + engram_embed_dim: int = 128, engram_n_heads: int = 2, + skipgram_hash_size: int = 0): super().__init__() self.logit_chunk_tokens = logit_chunk_tokens self.logit_softcap = logit_softcap @@ -440,8 +702,19 @@ def __init__(self, vocab_size, num_layers, dim, num_heads, num_kv_heads, ] self.final_norm = RMSNormNoWeight() - # INNOVATION: BigramHash on logits (None when bigram_hash_size=0) - self.bigram_hash = BigramHashEmbedding(bigram_hash_size, vocab_size) if bigram_hash_size > 0 else None + # Hash-based logit bias: EngramLite (moonshot) or plain BigramHash + if engram_lite_enabled: + self.bigram_hash = EngramLiteEmbedding( + hash_size=engram_hash_size, embed_dim=engram_embed_dim, + output_dim=vocab_size, n_heads=engram_n_heads) + elif bigram_hash_size > 0: + self.bigram_hash = BigramHashEmbedding(bigram_hash_size, vocab_size) + else: + self.bigram_hash = None + + # Skip-gram hash logit bias (moonshot; None when skipgram_hash_size=0) + self.skipgram_hash = (SkipGramHashEmbedding(skipgram_hash_size, vocab_size) + if skipgram_hash_size > 0 else None) # Zero-init output projections for b in self.blocks: @@ -472,6 +745,16 @@ def softcap(self, logits: mx.array) -> mx.array: c = self.logit_softcap return c * mx.tanh(logits / c) + def _apply_hash_biases(self, logits: mx.array, input_ids: mx.array) -> mx.array: + """Add BigramHash / EngramLite and SkipGram logit biases if enabled.""" + if self.bigram_hash is not None: + bias = self.bigram_hash(input_ids) + logits = logits + bias.reshape(-1, bias.shape[-1]).astype(logits.dtype) + if self.skipgram_hash is not None: + bias = self.skipgram_hash(input_ids) + logits = logits + bias.reshape(-1, bias.shape[-1]).astype(logits.dtype) + return logits + def __call__(self, input_ids: mx.array) -> mx.array: x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE)) if self.smear is not None: @@ -492,17 +775,34 @@ def __call__(self, input_ids: mx.array) -> mx.array: def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array: x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1]) y = target_ids.reshape(-1) - logits = x @ self.tok_emb.weight.astype(x.dtype).T logits = self.softcap(logits) + logits = self._apply_hash_biases(logits, input_ids) + return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean") - # Add BigramHash logit bias (skipped when bigram_hash is None) - if self.bigram_hash is not None: - bigram_bias = self.bigram_hash(input_ids) # (B, T, vocab) - bigram_bias = bigram_bias.reshape(-1, bigram_bias.shape[-1]) - logits = logits + bigram_bias.astype(logits.dtype) + def complementary_loss(self, input_ids: mx.array, target_ids: mx.array, + bigram_probs: mx.array, alpha: float) -> mx.array: + """Cross-entropy down-weighted for tokens easily predicted by bigrams. - return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean") + Tokens with high P_bigram(target | prev) receive a lower training weight, + pushing the neural model to specialise on what n-gram models cannot predict. + + bigram_probs: (V, V) float32 — pre-computed P(next | prev), Laplace-smoothed. 
+ alpha: complement strength in [0, 1]. 0 = standard CE. + """ + x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1]) + y = target_ids.reshape(-1) + logits = x @ self.tok_emb.weight.astype(x.dtype).T + logits = self.softcap(logits) + logits = self._apply_hash_biases(logits, input_ids) + ce_per_token = nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="none") + V = bigram_probs.shape[1] + flat_idx = input_ids.reshape(-1) * V + y + p_bigram = bigram_probs.reshape(-1)[flat_idx] + # Floor at 0.1 preserves at least 10% gradient for easy tokens; cap at 1.0 + weights = mx.clip(1.0 - alpha * p_bigram, 0.1, 1.0) + weights = weights / weights.mean() # normalise so effective LR is unchanged + return (ce_per_token * weights).mean() def token_losses(self, input_ids: mx.array, target_ids: mx.array) -> mx.array: """Return (B, T) per-token NLL — used for sliding-window eval.""" @@ -511,12 +811,19 @@ def token_losses(self, input_ids: mx.array, target_ids: mx.array) -> mx.array: y = target_ids.reshape(-1) logits = x @ self.tok_emb.weight.astype(x.dtype).T logits = self.softcap(logits) - if self.bigram_hash is not None: - bigram_bias = self.bigram_hash(input_ids).reshape(-1, self.tok_emb.weight.shape[0]) - logits = logits + bigram_bias.astype(logits.dtype) + logits = self._apply_hash_biases(logits, input_ids) nll = nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="none") return nll.reshape(B, T) + def token_logits(self, input_ids: mx.array) -> mx.array: + """Return (B, T, V) raw logits — used by BackoffNgramMixer eval.""" + B, T = input_ids.shape + x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1]) + logits = x @ self.tok_emb.weight.astype(x.dtype).T + logits = self.softcap(logits) + logits = self._apply_hash_biases(logits, input_ids) + return logits.reshape(B, T, -1) + # ============================================================================ # OPTIMIZERS (same structure as baseline) # ============================================================================ @@ -547,20 +854,23 @@ def step(self, params, grads, step, lr_mul): return out class SplitOptimizers: + # Parameter prefixes managed by this optimizer (blocks + all hash embeddings) + _MANAGED_PREFIXES = ("blocks.", "bigram_hash.", "skipgram_hash.") + def __init__(self, model, args): self.args = args params = dict(tree_flatten(model.parameters())) self.embed_key = "tok_emb.weight" self.matrix_keys = [ k for k, p in params.items() - if (k.startswith("blocks.") or k.startswith("bigram_hash.")) + if any(k.startswith(pfx) for pfx in self._MANAGED_PREFIXES) and p.ndim == 2 and not any(pat in k for pat in CONTROL_TENSOR_NAME_PATTERNS) ] self.scalar_keys = [ k for k, p in params.items() if k == "skip_weights" or ( - (k.startswith("blocks.") or k.startswith("bigram_hash.")) + any(k.startswith(pfx) for pfx in self._MANAGED_PREFIXES) and (p.ndim < 2 or any(pat in k for pat in CONTROL_TENSOR_NAME_PATTERNS)) ) ] @@ -1076,6 +1386,99 @@ def lora_loss(): return val_loss, val_bpb +# ============================================================================ +# MOONSHOT: BackoffNgramMixer sliding-window eval +# Processes the validation set in one sequential pass, scoring each token +# with the mixed (neural + n-gram) distribution. Windows are processed in +# order so the n-gram cache grows causally; each token is scored exactly once. 
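+# With window length S=eval_seq_len and stride K=eval_stride, the first
+# window scores all S positions and each later window scores only its final
+# K positions, so consecutive windows tile the stream with no double counting.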
+# ============================================================================
+def eval_val_sliding_ngram(args, model, val_tokens,
+                           base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+                           log_fn=None):
+    """Sliding-window eval with BackoffNgramMixer applied at each scored token.
+
+    For each validation token at position t (scored exactly once):
+      1. Compute neural logits via the standard sliding-window batching.
+      2. Pass logits and causal context to BackoffNgramMixer.score_and_update.
+      3. The mixer updates its n-gram cache with the true token (causal, AFTER score).
+
+    The neural-model batching is identical to eval_val_sliding, preserving
+    the long-context benefit of the sliding window.
+    """
+    seq_len = args.eval_seq_len
+    stride = args.eval_stride
+    batch_seqs = args.eval_batch_seqs
+    total_tokens = val_tokens.size - 1
+
+    # Keep the first window plus every window that contributes new tokens,
+    # i.e. whose length exceeds the seq_len - stride positions already scored.
+    window_starts = [ws for ws in range(0, total_tokens, stride)
+                     if ws == 0 or min(ws + seq_len, total_tokens) - ws > seq_len - stride]
+    total_windows = len(window_starts)
+
+    mixer = BackoffNgramMixer(
+        vocab_size=args.vocab_size,
+        max_order=args.ngram_max_order,
+        alpha=args.ngram_alpha,
+        alpha_mode="fixed",
+    )
+
+    loss_sum = 0.0
+    token_count = 0.0
+    byte_count = 0.0
+
+    model.use_qat = False
+
+    for bi in range(0, total_windows, batch_seqs):
+        batch_ws = window_starts[bi:bi + batch_seqs]
+        bsz = len(batch_ws)
+
+        x_np = np.zeros((bsz, seq_len), dtype=np.int32)
+        y_np = np.zeros((bsz, seq_len), dtype=np.int32)
+        wlens: list[int] = []
+        for i, ws in enumerate(batch_ws):
+            end = min(ws + seq_len, total_tokens)
+            wlen = end - ws
+            wlens.append(wlen)
+            x_np[i, :wlen] = val_tokens[ws:end]
+            y_np[i, :wlen] = val_tokens[ws + 1:end + 1]
+
+        # Get full (B, T, V) logits from neural model
+        x = mx.array(x_np)
+        logits_mx = model.token_logits(x)  # (B, T, V)
+        mx.eval(logits_mx)
+        logits_np = np.array(logits_mx.astype(mx.float32))
+
+        for i, ws in enumerate(batch_ws):
+            wlen = wlens[i]
+            # First window scores everything; later windows score only their
+            # final stride positions (the rest were scored by earlier windows).
+            s = 0 if ws == 0 else seq_len - stride
+            for t in range(s, wlen):
+                pos = ws + t
+                true_token = int(y_np[i, t])
+                # Causal context: the up-to-max_order tokens ending at pos (inclusive)
+                ctx_start = max(0, pos + 1 - args.ngram_max_order)
+                ctx = [int(v) for v in val_tokens[ctx_start: pos + 1]]
+                neural_lg = logits_np[i, t, :]  # (V,) raw logits
+                nll = mixer.score_and_update(ctx, true_token, neural_lg)
+                loss_sum += nll
+                token_count += 1.0
+                tgt = y_np[i, t:t+1]
+                prev = x_np[i, t:t+1]
+                tb = base_bytes_lut[tgt].astype(np.float64)
+                tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).astype(np.float64)
+                byte_count += float(tb.sum())
+
+        if log_fn and (bi // batch_seqs) % 50 == 0:
+            done = min(bi + batch_seqs, total_windows)
+            pct = done / total_windows * 100
+            rbpb = 0.0
+            if token_count > 0:
+                rbpb = (loss_sum / token_count) / math.log(2.0) * (token_count / byte_count)
+            log_fn(f"ngram_sliding_eval [{pct:5.1f}%] {done}/{total_windows} "
+                   f"windows running_bpb={rbpb:.6f}")
+
+    val_loss = loss_sum / token_count
+    val_bpb = (val_loss / math.log(2.0)) * (token_count / byte_count)
+    return val_loss, val_bpb
+
 # ============================================================================
 # MAIN
 # ============================================================================
@@ -1100,6 +1503,15 @@ def log(msg, console=True):
     val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
     base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, args.vocab_size)
 
+    # Complementary Training: pre-compute bigram stats if requested
+    bigram_probs_mx = None
+    if 
args.complement_alpha > 0.0: + log(f"complement_training: building bigram stats (alpha={args.complement_alpha})...") + bp_np = build_bigram_stats(args.data_path, args.vocab_size) + bigram_probs_mx = mx.array(bp_np, dtype=mx.float32) + mx.eval(bigram_probs_mx) + log("complement_training: bigram stats ready") + mx.random.seed(args.seed) train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) @@ -1112,6 +1524,11 @@ def log(msg, console=True): use_ortho_init=args.use_ortho_init, rope_dims=args.rope_dims, xsa_last_n=args.xsa_last_n, use_ln_scale=args.ln_scale_enabled, smear_enabled=args.smear_enabled, + engram_lite_enabled=args.engram_lite_enabled, + engram_hash_size=args.engram_hash_size, + engram_embed_dim=args.engram_embed_dim, + engram_n_heads=args.engram_n_heads, + skipgram_hash_size=args.skipgram_hash_size, ) opt = SplitOptimizers(model, args) @@ -1120,11 +1537,25 @@ def log(msg, console=True): # SWA buffer — starts at 60% of training when USE_SWA=1 swa = None - compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state) - compiled_loss_and_grad = mx.compile( - nn.value_and_grad(model, lambda x, y: model.loss(x, y)), - inputs=model.state, outputs=model.state, - ) + # Build compiled loss functions. The complementary training loss is a + # closure that captures bigram_probs_mx as a compile-time constant (it is + # not part of model.state, so MLX treats it as a static input). + def _make_compiled_fns(): + if bigram_probs_mx is not None: + def _train_loss(x, y): + return model.complementary_loss(x, y, bigram_probs_mx, args.complement_alpha) + else: + def _train_loss(x, y): + return model.loss(x, y) + c_loss = mx.compile( + lambda x, y: model.loss(x, y), + inputs=model.state, outputs=model.state) + c_lag = mx.compile( + nn.value_and_grad(model, _train_loss), + inputs=model.state, outputs=model.state) + return c_loss, c_lag + + compiled_loss, compiled_loss_and_grad = _make_compiled_fns() n_params = sum(int(np.prod(p.shape)) for _, p in tree_flatten(model.parameters())) log(f"run_id:{args.run_id}") @@ -1136,6 +1567,9 @@ def log(msg, console=True): f"ln_scale={args.ln_scale_enabled} xsa_last_n={args.xsa_last_n} xsa_layers={xsa_layers} " f"gptq_lite={args.use_gptq_lite} ttt={args.ttt_enabled} eval_mode={args.eval_mode} " f"use_swa={args.use_swa} swa_decay={args.swa_decay}") + log(f"moonshot: engram_lite={args.engram_lite_enabled}(hash={args.engram_hash_size},dim={args.engram_embed_dim},heads={args.engram_n_heads}) " + f"skipgram_hash={args.skipgram_hash_size} complement_alpha={args.complement_alpha} " + f"ngram_mixer={args.ngram_mixer_enabled}(alpha={args.ngram_alpha},order={args.ngram_max_order})") log(f"iterations:{args.iterations} train_batch_tokens:{args.train_batch_tokens} " f"grad_accum:{args.grad_accum_steps} seq_len:{args.train_seq_len}") log(f"optimizer: muon_keys:{len(opt.matrix_keys)} scalar_keys:{len(opt.scalar_keys)}") @@ -1187,10 +1621,7 @@ def log(msg, console=True): if _avg is not None: model.update(tree_unflatten(list(saved_state.items()))) - compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state) - compiled_loss_and_grad = mx.compile( - nn.value_and_grad(model, lambda x, y: model.loss(x, y)), - inputs=model.state, outputs=model.state) + compiled_loss, compiled_loss_and_grad = _make_compiled_fns() t0 = time.perf_counter() @@ -1209,12 +1640,7 @@ def log(msg, console=True): _prev_use_qat = _new_use_qat if _new_use_qat: log(f"qat_started:step={step} lr_mul={lr_mul:.4f} — 
recompiling graph") - compiled_loss = mx.compile( - lambda x, y: model.loss(x, y), - inputs=model.state, outputs=model.state) - compiled_loss_and_grad = mx.compile( - nn.value_and_grad(model, lambda x, y: model.loss(x, y)), - inputs=model.state, outputs=model.state) + compiled_loss, compiled_loss_and_grad = _make_compiled_fns() # Initialize EMA after ema_start_frac; initialize SWA at 60% of iterations est_total = args.iterations @@ -1319,6 +1745,11 @@ def log(msg, console=True): log(f"final_eval_mode:ttt_sliding rank:{args.ttt_rank} lr:{args.ttt_lr} steps:{args.ttt_steps} stride:{args.eval_stride}") q_val_loss, q_val_bpb = eval_val_sliding_ttt(args, model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, log_fn=log) + elif args.ngram_mixer_enabled: + log(f"final_eval_mode:sliding_ngram eval_seq_len:{args.eval_seq_len} stride:{args.eval_stride} " + f"batch_seqs:{args.eval_batch_seqs} ngram_alpha:{args.ngram_alpha} ngram_max_order:{args.ngram_max_order}") + q_val_loss, q_val_bpb = eval_val_sliding_ngram(args, model, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, log_fn=log) else: log(f"final_eval_mode:sliding_window eval_seq_len:{args.eval_seq_len} stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") q_val_loss, q_val_bpb = eval_val_sliding(args, model, val_tokens,