Cautious Muon + SP4096 + Depth Recurrence — val_bpb 1.1604 (non-record) #1381
**README.md** (new file, +62 lines; excerpt)
# Cautious Muon + SP4096 + Depth Recurrence + Parallel Residuals

**val_bpb = 1.1604** (3-seed mean, std = 0.0033)

## Results

| Seed | val_bpb | val_loss | Artifact Size |
|------|---------|----------|---------------|
| 42   | 1.1568  | 2.6619   | 15,179,504 B  |
| 314  | 1.1611  | 2.6717   | 15,173,470 B  |
| 999  | 1.1634  | 2.6770   | 15,159,223 B  |
| **Mean** | **1.1604** | **2.6702** | **15,170,732 B** |
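
The reported mean and standard deviation are the sample statistics over the three per-seed values; a minimal check (illustrative only, not part of the submission):

```python
import statistics

val_bpb = [1.1568, 1.1611, 1.1634]          # seeds 42, 314, 999
print(f"{statistics.mean(val_bpb):.4f}")    # 1.1604
print(f"{statistics.stdev(val_bpb):.5f}")   # 0.00335 -> reported as 0.0033 (sample std)
```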
## Key Technique: Cautious Muon (arXiv:2411.16085)

The primary modification applies the Cautious optimizer principle to the Muon optimizer. After Newton-Schulz orthogonalization and MuonEq-R row normalization, the update is masked so that it is applied only where the orthogonalized direction agrees with the sign of the raw gradient:

```python
caution_mask = (g * raw_grad > 0).to(g.dtype)               # keep entries whose sign agrees with the raw gradient
g = g * caution_mask / caution_mask.mean().clamp_min(1e-3)  # rescale to compensate for the masked-out fraction
```
This filters out "stale" momentum directions that disagree with the current gradient, providing ~1.47x effective convergence per step with zero parameter overhead and no impact on artifact size.
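
For orientation, here is a minimal, self-contained sketch of where the caution mask sits inside a Muon-style step (MuonEq-R row normalization and momentum warmup are omitted). The `newton_schulz` helper is the standard quintic iteration used by Muon; the function names, signatures, and hyperparameter values are illustrative and are not taken from this PR's code.

```python
import torch

def newton_schulz(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Approximately orthogonalize G with the quintic Newton-Schulz iteration."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + eps)
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T                                   # keep the short side first so X @ X.T stays small
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

@torch.no_grad()
def cautious_muon_step(param, grad, momentum, lr=0.02, beta=0.99):
    """One illustrative Cautious-Muon update for a single 2D weight matrix."""
    momentum.mul_(beta).add_(grad)                # heavy-ball momentum accumulation
    g = newton_schulz(momentum)                   # orthogonalized update direction
    mask = (g * grad > 0).to(g.dtype)             # caution mask: agreement with the raw gradient
    g = g * mask / mask.mean().clamp_min(1e-3)    # rescale so surviving entries keep the update scale
    param.add_(g, alpha=-lr)

# usage sketch
w = torch.randn(512, 512)
buf = torch.zeros_like(w)
cautious_muon_step(w, torch.randn_like(w), buf)
```

The rescaling by the mask mean keeps the expected update magnitude comparable to the unmasked update, mirroring the two-line snippet above.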
## Full Architecture Stack

Built on PR #1334 (aryanbhosale) with:
- **SP4096 BPE tokenizer** (from PR #1218, @clarkkev)
- **Depth recurrence** on layers 4,5 (13 virtual layers from 11 physical, activated at step 3000; see the sketch after this list)
- **Parallel residuals** from layer 7 (separate attn/MLP lanes with learnable merge)
- **MuonEq-R** row normalization before Newton-Schulz (arXiv:2603.28254)

Suggested change (review):
- before: `- **MuonEq-R** row normalization before Newton-Schulz (arXiv:2603.28254)`
- after: `- **MuonEq-R** row normalization after Newton-Schulz orthogonalization (arXiv:2603.28254)`
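
To make the depth-recurrence item concrete, here is a hypothetical helper (not from this PR's code) that builds the virtual layer schedule for an 11-layer stack with recurrence on layers 4 and 5; it reproduces the `virtual_layers` list printed in the training logs at step 3000.

```python
def virtual_layer_order(num_layers: int = 11, recur_layers: tuple = (4, 5)) -> list:
    """Repeat the recurrent block once, immediately after its first pass."""
    order = list(range(num_layers))
    insert_at = max(recur_layers) + 1             # re-enter the block right after its first pass
    return order[:insert_at] + list(recur_layers) + order[insert_at:]

# Matches the log line:
#   recurrence:activated at step 3000, virtual_layers=[0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]
print(virtual_layer_order())   # [0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]
```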

**submission.json** (new file, +37 lines; excerpt)
{
  "author": "X-Abhishek-X",
  "github_id": "X-Abhishek-X",
  "name": "Cautious Muon + SP4096 + Depth Recurrence + Parallel Residuals",
  "blurb": "Applies Cautious Muon (arXiv:2411.16085) to the Muon optimizer — masks Newton-Schulz updates where the orthogonalized direction disagrees with the raw gradient sign, providing ~1.47x effective convergence per step with zero parameter overhead. Built on PR #1334 (aryanbhosale) base with SP4096 vocabulary, depth recurrence (layers 4,5), parallel residuals (from layer 7), MuonEq-R, QK-Gain 5.0, and full GPTQ INT6 + Brotli compression. Mean val_bpb = 1.1604 (3 seeds, std = 0.0033).",
  "date": "2026-04-05",

Suggested change (review):
- before: `"date": "2026-04-05",`
- after: `"date": "2026-04-05T00:00:00Z",`

**Copilot** (AI) commented on Apr 5, 2026:
Most existing 10min_16mb `submission.json` files include a `bytes_total` field (e.g., `records/track_10min_16mb/2026-03-22_11L_EMA_GPTQ-lite_warmdown3500_QAT015_1.1233/submission.json:9`). This submission uses `artifact_bytes_*` instead; consider adding `bytes_total` (and optionally `bytes_code`) for consistency and easier leaderboard/tooling ingestion.

**Training log, seed 314** (new file, +134 lines)
W0405 12:33:07.793000 58424 torch/distributed/run.py:803]
W0405 12:33:07.793000 58424 torch/distributed/run.py:803] *****************************************
W0405 12:33:07.793000 58424 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0405 12:33:07.793000 58424 torch/distributed/run.py:803] *****************************************
Hyperparameters:
adam_eps: 1e-08
adam_wd: 0.02
beta1: 0.9
beta2: 0.95
cautious_muon: True
compressor: brotli
data_dir: ./data/
datasets_dir: ./data/datasets/fineweb10B_sp4096
distributed: True
ema_decay: 0.997
embed_lr: 0.6
embed_wd: 0.09
embedding_dim: 512
eval_seq_len: 2048
eval_stride: 64
gptq_calibration_batches: 64
gptq_enabled: True
gptq_reserve_seconds: 10.0
grad_accum_steps: 1
grad_clip_norm: 0.3
head_lr: 0.008
is_main_process: True
iterations: 20000
ln_scale: True
local_rank: 0
logfile: logs/523d357e-1519-45f9-bc20-69bbfb520c1b.txt
logit_softcap: 30.0
matrix_lr: 0.02
max_wallclock_seconds: 600.0
min_lr: 0.0
mlp_mult: 4.0
model_dim: 512
model_path: final_model.pt
muon_backend_steps: 5
muon_beta2: 0.95
muon_momentum: 0.99
muon_momentum_warmup_start: 0.92
muon_momentum_warmup_steps: 1500
muon_wd: 0.09
num_heads: 8
num_kv_heads: 4
num_layers: 11
parallel_start_layer: 7
qk_gain_init: 5.0
quantized_model_path: final_model.int6.ptz
rank: 0
recur_layers: 4,5
recur_start_step: 3000
rope_base: 10000.0
rope_dims: 16
rope_train_seq_len: 2048
run_id: 523d357e-1519-45f9-bc20-69bbfb520c1b
scalar_lr: 0.02
seed: 314
skip_gates_enabled: True
sliding_window_enabled: True
tie_embeddings: True
tied_embed_init_std: 0.005
tied_embed_lr: 0.03
tokenizer_path: ./data/tokenizers/fineweb_4096_bpe.model
train_batch_tokens: 786432
train_files: ./data/datasets/fineweb10B_sp4096/fineweb_train_*.bin
train_log_every: 500
train_seq_len: 2048
ttt_batch_seqs: 32
ttt_chunk_tokens: 32768
ttt_enabled: False
ttt_epochs: 3
ttt_freeze_blocks: 0
ttt_grad_clip: 1.0
ttt_lr: 0.002
ttt_momentum: 0.9
val_batch_tokens: 524288
val_files: ./data/datasets/fineweb10B_sp4096/fineweb_val_*.bin
val_loss_every: 4000
ve_dim: 128
ve_enabled: True
ve_layers: 9,10
vocab_size: 4096
warmdown_frac: 0.667
warmup_steps: 20
world_size: 8
xsa_last_n: 11
train_shards: 80
val_tokens: 45508608
model_params:34401372
gptq:reserving 10s, effective=590000ms
warmup_step: 1/20
warmup_step: 2/20
warmup_step: 3/20
warmup_step: 4/20
warmup_step: 5/20
warmup_step: 6/20
warmup_step: 10/20
warmup_step: 20/20
0/20000 val_loss: 8.3172 val_bpb: 3.6146
1/20000 train_loss: 8.3192 train_time: 0.0m tok/s: 8417806
2/20000 train_loss: 11.4353 train_time: 0.0m tok/s: 8294705
3/20000 train_loss: 8.8870 train_time: 0.0m tok/s: 8200165
4/20000 train_loss: 8.0296 train_time: 0.0m tok/s: 8159612
5/20000 train_loss: 8.6996 train_time: 0.0m tok/s: 8141617
500/20000 train_loss: 3.0633 train_time: 0.8m tok/s: 7939661
1000/20000 train_loss: 2.9869 train_time: 1.7m tok/s: 7911553
1500/20000 train_loss: 2.9710 train_time: 2.5m tok/s: 7908791
2000/20000 train_loss: 2.6991 train_time: 3.3m tok/s: 7909908
2500/20000 train_loss: 2.7380 train_time: 4.1m tok/s: 7911617
3000/20000 train_loss: 2.7805 train_time: 5.0m tok/s: 7913690
recurrence:activated at step 3000, virtual_layers=[0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]
3500/20000 train_loss: 2.6930 train_time: 6.1m tok/s: 7516407
4000/20000 train_loss: 2.6237 train_time: 7.1m tok/s: 7424228
4000/20000 val_loss: 2.6465 val_bpb: 1.1501
4500/20000 train_loss: 2.5745 train_time: 8.0m tok/s: 7355829
5000/20000 train_loss: 2.5200 train_time: 9.0m tok/s: 7301867
5449/20000 val_loss: 2.5373 val_bpb: 1.1027
stopping_early: wallclock_cap train_time: 590030ms step: 5449/20000
peak memory allocated: 30120 MiB reserved: 30154 MiB
ema:applying EMA weights
pre-quantization post-ema val_loss:2.53590710 val_bpb:1.10207601 eval_time:2003ms
Serialized model: 132406149 bytes
Code size: 24659 bytes
GPTQ:collecting Hessians from calibration data...
GPTQ:collected 66 Hessians in 9.7s
GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search
selective_prune: unpruned=16.59MB target=16.0MB
selective_prune: pruning 4714416/8989294 lowest-error ±1 values (excess=589302B)
Serialized model int6+brotli: 15148811 bytes
Total submission size int6+brotli: 15173470 bytes
final_int6_roundtrip val_loss:2.71751254 val_bpb:1.18099965 eval_time:8188ms
final_int6_sliding_window val_loss:2.67174312 val_bpb:1.16110878 eval_time:76629ms

**Training log, seed 42** (new file, +134 lines)
W0405 12:16:26.757000 47726 torch/distributed/run.py:803]
W0405 12:16:26.757000 47726 torch/distributed/run.py:803] *****************************************
W0405 12:16:26.757000 47726 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0405 12:16:26.757000 47726 torch/distributed/run.py:803] *****************************************
Hyperparameters:
adam_eps: 1e-08
adam_wd: 0.02
beta1: 0.9
beta2: 0.95
cautious_muon: True
compressor: brotli
data_dir: ./data/
datasets_dir: ./data/datasets/fineweb10B_sp4096
distributed: True
ema_decay: 0.997
embed_lr: 0.6
embed_wd: 0.09
embedding_dim: 512
eval_seq_len: 2048
eval_stride: 64
gptq_calibration_batches: 64
gptq_enabled: True
gptq_reserve_seconds: 10.0
grad_accum_steps: 1
grad_clip_norm: 0.3
head_lr: 0.008
is_main_process: True
iterations: 20000
ln_scale: True
local_rank: 0
logfile: logs/6e35fe72-b1cf-49e5-95dc-2b7d967c8075.txt
logit_softcap: 30.0
matrix_lr: 0.02
max_wallclock_seconds: 600.0
min_lr: 0.0
mlp_mult: 4.0
model_dim: 512
model_path: final_model.pt
muon_backend_steps: 5
muon_beta2: 0.95
muon_momentum: 0.99
muon_momentum_warmup_start: 0.92
muon_momentum_warmup_steps: 1500
muon_wd: 0.09
num_heads: 8
num_kv_heads: 4
num_layers: 11
parallel_start_layer: 7
qk_gain_init: 5.0
quantized_model_path: final_model.int6.ptz
rank: 0
recur_layers: 4,5
recur_start_step: 3000
rope_base: 10000.0
rope_dims: 16
rope_train_seq_len: 2048
run_id: 6e35fe72-b1cf-49e5-95dc-2b7d967c8075
scalar_lr: 0.02
seed: 42
skip_gates_enabled: True
sliding_window_enabled: True
tie_embeddings: True
tied_embed_init_std: 0.005
tied_embed_lr: 0.03
tokenizer_path: ./data/tokenizers/fineweb_4096_bpe.model
train_batch_tokens: 786432
train_files: ./data/datasets/fineweb10B_sp4096/fineweb_train_*.bin
train_log_every: 500
train_seq_len: 2048
ttt_batch_seqs: 32
ttt_chunk_tokens: 32768
ttt_enabled: False
ttt_epochs: 3
ttt_freeze_blocks: 0
ttt_grad_clip: 1.0
ttt_lr: 0.002
ttt_momentum: 0.9
val_batch_tokens: 524288
val_files: ./data/datasets/fineweb10B_sp4096/fineweb_val_*.bin
val_loss_every: 4000
ve_dim: 128
ve_enabled: True
ve_layers: 9,10
vocab_size: 4096
warmdown_frac: 0.667
warmup_steps: 20
world_size: 8
xsa_last_n: 11
train_shards: 80
val_tokens: 45508608
model_params:34401372
gptq:reserving 10s, effective=590000ms
warmup_step: 1/20
warmup_step: 2/20
warmup_step: 3/20
warmup_step: 4/20
warmup_step: 5/20
warmup_step: 6/20
warmup_step: 10/20
warmup_step: 20/20
0/20000 val_loss: 8.3187 val_bpb: 3.6152
1/20000 train_loss: 8.3201 train_time: 0.0m tok/s: 8409702
2/20000 train_loss: 11.3749 train_time: 0.0m tok/s: 8301223
3/20000 train_loss: 8.9270 train_time: 0.0m tok/s: 8232435
4/20000 train_loss: 8.0370 train_time: 0.0m tok/s: 8197540
5/20000 train_loss: 8.6257 train_time: 0.0m tok/s: 8158832
500/20000 train_loss: 3.0617 train_time: 0.8m tok/s: 7952184
1000/20000 train_loss: 2.9819 train_time: 1.7m tok/s: 7929746
1500/20000 train_loss: 2.9701 train_time: 2.5m tok/s: 7920772
2000/20000 train_loss: 2.6938 train_time: 3.3m tok/s: 7916245
2500/20000 train_loss: 2.7396 train_time: 4.1m tok/s: 7918140
3000/20000 train_loss: 2.7789 train_time: 5.0m tok/s: 7919917
recurrence:activated at step 3000, virtual_layers=[0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]
3500/20000 train_loss: 2.6938 train_time: 6.1m tok/s: 7516573
4000/20000 train_loss: 2.6257 train_time: 7.1m tok/s: 7425494
4000/20000 val_loss: 2.6472 val_bpb: 1.1505
4500/20000 train_loss: 2.5745 train_time: 8.0m tok/s: 7356557
5000/20000 train_loss: 2.5209 train_time: 9.0m tok/s: 7302902
5450/20000 val_loss: 2.5381 val_bpb: 1.1030
stopping_early: wallclock_cap train_time: 590058ms step: 5450/20000
peak memory allocated: 30120 MiB reserved: 30154 MiB
ema:applying EMA weights
pre-quantization post-ema val_loss:2.53669680 val_bpb:1.10241920 eval_time:2002ms
Serialized model: 132406149 bytes
Code size: 24659 bytes
GPTQ:collecting Hessians from calibration data...
GPTQ:collected 66 Hessians in 9.8s
GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search
selective_prune: unpruned=16.59MB target=16.0MB
selective_prune: pruning 4683904/8992183 lowest-error ±1 values (excess=585488B)
Serialized model int6+brotli: 15154845 bytes
Total submission size int6+brotli: 15179504 bytes
final_int6_roundtrip val_loss:2.70991443 val_bpb:1.17769759 eval_time:23408ms
final_int6_sliding_window val_loss:2.66190192 val_bpb:1.15683191 eval_time:99175ms

Review comment: This submission is labeled as a "Record" (folder/PR title), but the reported mean val_bpb = 1.1604 is substantially worse than the current 10min_16mb leaderboard entries (e.g. 1.1228 in the repo README). Consider renaming the PR/folder/README `name` to avoid implying it is a new SOTA record if it is intended as a non-record/ablation submission.