
Non-record: Soft MoE Exploration — Dense Gating Fixes Sparse Router Collapse Under 16MB (WIP, val_bpb=1.1826) #660

Draft

HugoOchoaLP wants to merge 12 commits into openai:main from HugoOchoaLP:test2

Conversation

@HugoOchoaLP

Non-record: Soft MoE Under 16MB Parameter Constraint

Score: val_bpb = 1.1826 (11L, 8xH100, 600s — artifact 17.3MB, 10L version pending to fit under 16MB)

Summary

Systematic exploration of Mixture of Experts (MoE) for parameter golf. Standard sparse MoE fails on three counts: router collapse (a 98%/2% expert split), torch.compile incompatibility (3x slower steps), and parameter overhead. A dense "Soft MoE" variant fixes all three: learned soft gating over all experts eliminates collapse, restores torch.compile compatibility (636ms vs 2309ms per step), and reaches 1.1826 val_bpb against the 1.2244 baseline.
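In code, the difference from sparse routing is small. A minimal sketch of the dense soft-gating idea (class and variable names here are illustrative, not the exact code in train_gpt.py):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftMoE(nn.Module):
    """Dense 'soft' MoE: all experts run on all tokens and a learned
    per-token softmax gate blends the outputs. All tensor shapes are
    static, so torch.compile(fullgraph=True) stays usable."""

    def __init__(self, dim: int, mlp_mult: float = 4.0, num_experts: int = 2):
        super().__init__()
        # Each expert gets hidden = mlp_mult * dim / num_experts, so the
        # total parameter count matches a single full-size MLP.
        hidden = int(mlp_mult * dim / num_experts)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))
            for _ in range(num_experts)
        )
        self.gate = nn.Linear(dim, num_experts)  # per-token gating scores

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, D)
        gates = F.softmax(self.gate(x), dim=-1)                    # (B, T, E)
        outs = torch.stack([e(x) for e in self.experts], dim=-1)  # (B, T, D, E)
        return (outs * gates.unsqueeze(-2)).sum(dim=-1)            # soft blend
```

Because every expert runs on every token, there are no variable-size tensors, which is what lets fullgraph compilation succeed where sparse dispatch failed.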

Experiment Results

| Config | Steps | val_bpb | ms/step | Artifact | Expert Balance |
|---|---|---|---|---|---|
| 11L no MoE | 138 | 3.26 | 794 | 17.5MB | n/a |
| 9L sparse MoE | 52 | 3.86 | 2309 | too big | n/a |
| 9L sparse, aux=0.1 | 89 | 3.31 | 1415 | 13.1MB | 2%/98% collapsed |
| 9L Soft MoE | 189 | 3.25 | 636 | 14.6MB | balanced |
| 11L Soft MoE (8xH100) | 4704 | 1.1826 | 128 | 17.3MB | balanced |

Status

- [x] Sparse MoE experiments (negative result)
- [x] Soft MoE implementation and validation
- [x] Full 8xH100 run (11L, over 16MB)
- [ ] 10L run fitting under 16MB
- [ ] 3-seed statistical significance

HugoOchoaLP and others added 12 commits March 21, 2026 17:47
- train_gpt_rank1_int5mlp_swa.py: adds TrigramHashEmbedding module
  (TRIGRAM_HASH_BUCKETS env var, default 0=off) that hashes 3 consecutive
  tokens via polynomial hash into a learned embedding table, projected to
  model_dim and added to the embedding sum alongside BigramHash

- train_gpt_recurrent.py: based on rank1, adds NUM_LOOPS env var (default 1)
  for depth recurrence — loops the same set of unique layer weights multiple
  times per forward pass, freeing parameter budget for wider dims or larger
  embedding tables

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
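A minimal sketch of the trigram hashing idea from the commit above, assuming a simple rolling polynomial hash (the prime, padding scheme, and dimensions here are illustrative; only the TRIGRAM_HASH_BUCKETS knob is from the commit):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TrigramHashEmbedding(nn.Module):
    """Hash each window of 3 consecutive tokens into one of num_buckets
    slots via a polynomial hash, look up a learned embedding, and project
    it to model_dim so it can be added to the embedding sum."""

    def __init__(self, num_buckets: int, emb_dim: int, model_dim: int):
        super().__init__()
        self.num_buckets = num_buckets
        self.table = nn.Embedding(num_buckets, emb_dim)
        self.proj = nn.Linear(emb_dim, model_dim, bias=False)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # idx: (B, T) token ids; pad with id 0 at the sequence start so
        # positions 0 and 1 still get a (degenerate) trigram.
        prev1 = F.pad(idx, (1, 0), value=0)[:, :-1]  # token at t-1
        prev2 = F.pad(idx, (2, 0), value=0)[:, :-2]  # token at t-2
        P = 1000003  # arbitrary large prime; an assumption, not from the PR
        h = ((prev2 * P + prev1) * P + idx) % self.num_buckets
        return self.proj(self.table(h))              # (B, T, model_dim)
```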
Based on rank openai#1 submission. Key changes:
- MoEMLP: num_experts small MLPs (hidden_dim = mlp_mult*dim/num_experts)
  so total params stay equal to a single full-size MLP
- Top-1 routing via argmax with straight-through estimator for gradient
- Switch Transformer load balancing aux loss (coef 0.01) to prevent collapse
- Dense computation (all experts on all tokens) for torch.compile compatibility
- Env vars: NUM_EXPERTS=4, TOP_K=1, MOE_AUX_LOSS_COEF=0.01
- All rank openai#1 techniques preserved: int5 MLP quant, BigramHash, SmearGate, SWA

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
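The top-1 routing and aux loss described in this commit follow the standard straight-through plus Switch-style load-balancing recipe; a sketch of that mechanism, not the PR's exact code:

```python
import torch
import torch.nn.functional as F

def top1_route(logits: torch.Tensor, aux_coef: float = 0.01):
    """logits: (num_tokens, num_experts) router scores.
    Returns (gates, aux_loss) for dense all-expert computation."""
    probs = F.softmax(logits, dim=-1)
    hard = F.one_hot(probs.argmax(dim=-1), probs.size(-1)).float()
    # Straight-through estimator: forward uses the hard one-hot,
    # backward flows through the soft probabilities.
    gates = hard + probs - probs.detach()
    # Switch Transformer balancing loss: fraction of tokens dispatched
    # to each expert times its mean router probability, scaled by the
    # number of experts; minimized when routing is uniform.
    frac = hard.mean(dim=0)
    importance = probs.mean(dim=0)
    aux = aux_coef * probs.size(-1) * (frac * importance).sum()
    return gates, aux
```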
Key fixes from prior dense implementation:
- MoEMLP: sparse dispatch — only the selected expert runs per token.
  Experts are now full-size (hidden = mlp_mult * dim), so FLOPs per token
  match a single MLP while the model can truly specialize. Default 2 experts.
- SharedRoutedMLP: shared MLP (2x expansion) for all tokens + small routed
  experts (1x expansion) for specialization. More param-efficient for 16MB.
  Toggle with MOE_MODE=shared_routed (default: standard).
- Router weights (too small for Muon) now routed to AdamW scalar_params.
- torch.compile switched to dynamic=True for sparse dispatch compatibility.
- Expert utilization logging: router weight norms printed every log step
  to detect collapse (all tokens routing to one expert).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
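The sparse dispatch this commit describes boils down to a per-expert gather; an illustrative sketch, which also shows why fullgraph compilation breaks:

```python
import torch
import torch.nn as nn

def sparse_dispatch(x: torch.Tensor, expert_idx: torch.Tensor,
                    experts: nn.ModuleList) -> torch.Tensor:
    """x: (N, dim) flattened tokens; expert_idx: (N,) router argmax.
    Each token is processed only by its chosen expert, so per-token
    FLOPs match a single MLP even though experts are full-size."""
    out = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        mask = expert_idx == e
        if mask.any():                   # data-dependent, variable-size
            out[mask] = expert(x[mask])  # slice: this breaks fullgraph
    return out
```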
- Replace SWA (uniform average) with EMA (decay=0.998): maintains a
  running exponential moving average during warmdown, giving more weight
  to recent checkpoints. Configurable via EMA_DECAY env var.
- NUM_LAYERS default: 10 → 11 (matches top leaderboard entries)
- WARMDOWN_ITERS default: 3000 → 3500

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
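The SWA-to-EMA change amounts to a one-line update rule; a minimal sketch, assuming a separate EMA copy of the model is maintained during warmdown:

```python
import torch
import torch.nn as nn

@torch.no_grad()
def update_ema(ema_model: nn.Module, model: nn.Module, decay: float = 0.998):
    """Exponential moving average of weights: newer checkpoints get more
    weight than under SWA's uniform average."""
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.lerp_(p, 1.0 - decay)  # ema = decay*ema + (1-decay)*param
```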
- Fix router quantization: add 'router' to FP16_KEEP_NAME_PATTERNS so
  router weights stay fp16 instead of being destroyed by int5 quant.
  Remove hardcoded 'blocks.8.attn.c_k'; dynamically keep last layer's
  c_k in fp16 regardless of num_layers.
- Fix utilization logging: store _last_routing in MoEMLP/SharedRoutedMLP
  and log actual token fracs per expert (detects collapse vs ~50/50 healthy).
- Selective MoE: MOE_START_LAYER env var (default num_layers//2) — only
  deeper layers get MoE, halving the extra param cost.
- Top-2 routing: TOP_K default changed to 2, with renormalized gates.
  More stable training than top-1. SharedRoutedMLP remains top-1.
- Parameter budget breakdown logged at startup with estimated artifact size.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
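Top-2 routing with renormalized gates, as described above, is compact; an illustrative sketch:

```python
import torch
import torch.nn.functional as F

def top2_gates(logits: torch.Tensor):
    """logits: (num_tokens, num_experts). Keep the two largest router
    probabilities per token and rescale them to sum to 1."""
    probs = F.softmax(logits, dim=-1)
    vals, idx = probs.topk(2, dim=-1)          # (N, 2) each
    gates = vals / vals.sum(-1, keepdim=True)  # renormalized gates
    return gates, idx                          # blend the two chosen experts
```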
… from pruning

- SharedRoutedMLP: hardcode routed experts to MLP(dim, 1.0) — simpler
  and more predictable than max(mlp_mult-2.0, 1.0)
- Skip torch.compile for MoE (num_experts>1): sparse dispatch with nested
  top-k loops causes recompile thrashing; uncompiled is faster in practice.
  Non-MoE path keeps fullgraph=True for full speed.
- Magnitude pruning: add a `'router' not in name` guard — router weights are
  tiny and pruning them would destroy routing decisions

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
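A minimal sketch of magnitude pruning with the router guard described above (the threshold logic is illustrative; the PR's quantization/pruning pipeline differs in detail):

```python
import torch
import torch.nn as nn

@torch.no_grad()
def magnitude_prune(model: nn.Module, sparsity: float = 0.1):
    """Zero out the smallest-magnitude fraction of each weight matrix,
    skipping router weights and non-matrix parameters."""
    for name, p in model.named_parameters():
        if 'router' in name or p.ndim < 2:
            continue  # routers and scalar/vector params stay intact
        k = int(p.numel() * sparsity)
        if k == 0:
            continue
        thresh = p.abs().flatten().kthvalue(k).values
        p.masked_fill_(p.abs() <= thresh, 0.0)
```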
Dense Soft MoE runs ALL experts on ALL tokens with learned per-token gates
(softmax-blended). No routing collapse, no variable-size tensors.
With 2 experts each at mlp_mult/2 hidden dim, total params = 1 regular MLP.

- Re-enables torch.compile fullgraph=True for soft mode (compile-safe)
- Sparse modes (standard, shared_routed) still skip compile
- Gate weights added to AdamW scalar group (not Muon)
- Block.forward isinstance check updated to include SoftMoE

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
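The resulting compile policy fits in a few lines; a sketch, with `moe_mode` as a stand-in for the script's actual config handling:

```python
import torch

def maybe_compile(model, moe_mode: str):
    # Dense soft MoE and the non-MoE path are shape-stable, so fullgraph
    # compilation is safe; sparse dispatch produces variable-size tensors
    # and recompile thrashing, so those modes run eager.
    if moe_mode in ('off', 'soft'):
        return torch.compile(model, fullgraph=True)
    return model  # sparse modes: uncompiled is faster in practice
```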
Explores MoE under the 16MB constraint. Key findings:
- Sparse MoE fails: router collapse (98%/2% split) even with 10x aux loss,
  and breaks torch.compile, causing a 2-3x step time regression
- Soft MoE (dense gating, all experts on all tokens) fixes both:
  no collapse, compile-friendly, 636ms/step vs 794ms baseline
- Best result: 1.1826 bpb on 11L config (17.3MB, over limit)
- 10L run pending to confirm fit under 16MB

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@HugoOchoaLP HugoOchoaLP reopened this Mar 25, 2026
@HugoOchoaLP HugoOchoaLP marked this pull request as draft March 25, 2026 00:06
@MatoTeziTanka

Community Review — Non-record: Soft MoE Exploration — Dense Gating Fixes Sparse Router Collapse Under 16MB (WIP, val_bpb=1.1826)

BPB: 1.1826 | Compliance: LOOKS CLEAN — pure-neural submission, no TTT/SLOT/n-gram-cache

What I found in the code (head SHA 016985b6f8c0, file records/track_non_record_16mb/2026-03-24_SoftMoE_exploration/train_gpt.py):

Static code review found no TTT adaptation function, no SLOT optimization loop, no n-gram-cache class, and no pre-quant val-token fine-tune. The eval path uses the standard sliding-window stride-64 pattern. The submission is a pure-neural architecture iteration on the standard SP1024/SP4096/SP8192 baseline.

CPU smoke test (CT2038 proteus-engine, 2026-04-11): import OK in 0.07s, dim=512, layers=11, vocab=1024, code=65481 B, SMOKE_TEST_PASS

Verdict: LOOKS CLEAN.

Recommendation to @cocohearts @valerio-oai @0hq @yuzhougu-oai @notapplica: MERGE pending the usual record-track checks (3-seed validation, under-16MB artifact cap, ≤600s train + ≤600s eval on 8×H100 SXM). No compliance flags from the classification pass — this looks like a clean pure-neural iteration on the standard baseline.

Auto-classification caveat: this review was drafted by the AST-based classifier. If there's a non-standard eval mechanism (logit postprocessing, hedge mixing, etc.) that I missed because it's factored into a helper file or a non-standard function name, please flag it and I'll re-run the audit manually.


Reviewed by @MatoTeziTanka (The Agora). Classification via deterministic AST-based classify_prs.py (pattern bank derived from ~65 manually-reviewed PRs earlier in the 2026-04-11 sweep). This review was auto-drafted from a template and spot-checked before posting — if the template misread your code, please call it out so I can iterate the classifier.
