New WR: Muon improvements: faster step, corrected learning rates (30 steps) #146

Merged
ClassicLarry merged 2 commits into KellerJordan:master from varunneal:fix-lr
Nov 6, 2025

Conversation

@varunneal
Contributor

@varunneal varunneal commented Oct 28, 2025

Faster Muon step, corrected learning rates

This record reduces Muon's step time and halves its learning rate.

Timing and Validation

This record reduces final training by 30 steps and decreases time per step by around 1%.

This PR:

import scipy.stats
import torch

losses = [3.2804, 3.2754, 3.2753, 3.2800, 3.2780, 3.2813, 3.2778, 3.2771, 3.2780, 3.2783]
times = [139.441, 139.780, 139.889, 139.464, 139.761, 139.411, 139.555, 139.570, 139.804, 139.847]

print("p=%.4f" % scipy.stats.ttest_1samp(losses, 3.28, alternative="less").pvalue)
# p=0.0083

print("losses:", torch.std_mean(torch.tensor(losses)))
# losses: (std=0.0020, mean=3.2782)

print("time:", torch.std_mean(torch.tensor(times)))
# time: (std=0.1825, mean=139.6522)

Previous PR (timed on same machine):

import scipy.stats
import torch

times = [143.018, 142.641, 142.789, 143.072, 143.241]

print("time:", torch.std_mean(torch.tensor(times)))
# time: (std=0.2375, mean=142.9522)

In total, this corresponds to a $3.30$ second decrease in training time. Roughly half of this comes from the reduced number of steps, and the other half should come from the more efficient Muon step. I expect the timing improvements from the Muon vectorization to vary moderately by machine.

Thank you to Prime Intellect for sponsoring my research.

Changes

(1) LR Adjustment

I found that the Muon learning rate was ~twice as high as it should be, so I've decreased it to 0.03. The lower optimal LR may be due in part to second-order effects of NorMuon.

Following the theory that the effective learning rate should be proportional to sqrt(output_dim), I have increased lr_mul on the MLP up-projection to 2.0. I have also removed the logic that requires all parameters in the same group to share the same learning rate and weight decay.
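As a quick sanity check on the multiplier (an illustrative sketch only; the 0.03 base LR and the 2.0 lr_mul come from this PR, the variable names are made up):

import math

dim = 768
base_lr = 0.03                          # new Muon LR after halving from 0.06
lr_mul_up = math.sqrt((4 * dim) / dim)  # sqrt of the output_dim ratio = 2.0
print(base_lr * lr_mul_up)              # 0.06 effective LR on the MLP up-projection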

(2) Muon step update

Vectorization

I vectorized several loops inside the Muon step, which slightly decreases step time. I am guessing we can apply torch.compile to a subpart of step for further gains as well. I also moved the momentum buffers from per-parameter state to per-group properties, though this requires adding a reset() (similar to Yarn).
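As a minimal illustration of the kind of vectorization involved (not the repo's actual step(); torch._foreach_lerp_ is just one way to batch a per-tensor update):

import torch

momentum = 0.95
grads = [torch.randn(768, 768) for _ in range(6)]
bufs_loop = [torch.zeros_like(g) for g in grads]
bufs_vec = [torch.zeros_like(g) for g in grads]

# per-parameter Python loop
for buf, g in zip(bufs_loop, grads):
    buf.lerp_(g, 1 - momentum)

# same update issued as one batched call over the whole list
torch._foreach_lerp_(bufs_vec, grads, 1 - momentum)

print(all(torch.allclose(a, b) for a, b in zip(bufs_loop, bufs_vec)))  # True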

Moved attention reshape

Moving the attention parameter reshape (from [dim, 4 * dim] -> [4, dim, dim]) to an earlier stage ensures that NorMuon gets applied columnwise to each parameter rather than rowwise. Empirical testing suggests NorMuon is more effective over the output dim (columnwise) than over the input dim (rowwise).
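For reference, a hedged sketch of the reshape (assuming the four square matrices are concatenated along the last dimension; the norms are only there to show which axis "columnwise" refers to, not NorMuon's actual statistic):

import torch

dim = 768
w = torch.randn(dim, 4 * dim)             # fused attention parameter
w4 = w.view(dim, 4, dim).transpose(0, 1)  # -> [4, dim, dim], one square matrix per projection

col_stats = w4.norm(dim=-2, keepdim=True)  # one value per output column: shape [4, 1, 768]
row_stats = w4.norm(dim=-1, keepdim=True)  # one value per input row: shape [4, 768, 1]
print(w4.shape, col_stats.shape, row_stats.shape)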

(3) Corrections

As noted here, the current logic for get_lr does not flatten out during the iteration extension. I've corrected this issue, as well as a similar issue in get_ws.
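Purely to illustrate what "flattening out" means here (a generic sketch, not the repo's get_lr/get_ws; planned_steps and schedule are made-up names): clamping the progress fraction makes a schedule hold its final value during the extension steps instead of extrapolating past them.

def scheduled_value(step, planned_steps, schedule):
    frac = min(step / planned_steps, 1.0)  # steps beyond planned_steps reuse the final value
    return schedule(frac)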

Additionally, I corrected a subtle bug where gradients were being summed over grad_accum_steps but averaged over ranks. In practice this is mostly irrelevant due to magnitude invariance; however, it causes minor precision issues for $<8$ devices.

@varunneal varunneal changed the title New PR: Muon improvements: faster step, corrected learning rates (-3.2 seconds) New WR: Muon improvements: faster step, corrected learning rates (-3.2 seconds) Oct 28, 2025
@ClassicLarry
Collaborator

Wow this is great! I'll get #144 validated and merged to give this one some space.

I believe that eff_lr_val was already functioning as a dummy param since max(1, 1/4) was evaluating to 1, which is something I had looked into when adding the attn reshape:

import torch
import torch.nn as nn

p_example = nn.Parameter(torch.empty(768, 768 * 4))
max(1, p_example.size(-2) / p_example.size(-1)) ** 0.5  # max(1, 0.25) ** 0.5 == 1.0

The MLP lr issue was introduced in #109 and at the time appeared to have minimal impact on loss.
#109 (comment).
When the PR was originally made, it was believed to just be an engineering optimization so loss was only shown for 1 run of 3.2780. Perhaps the loss impact was meaningful on average and it got hidden by the following PR.

If I'm interpreting these numbers right, then the 0.06->0.03 update is effectively cutting attn LR in half, and also cutting MLP_out lr in half, while leaving MLP_in lr constant. When I had previously tested Muon lr in #140 it seemed pretty stable around 0.06. This is making me wonder if the substantial drop here is tied to NorMuon not having an optimal lr of 0.06.

The update of the ws schedule from (3, 7, 11) -> (3, 5, 7, 9, 11, 13) is big! I imagine that increases time per step slightly. If I break out the timing contributions:

  • 30 steps: -1.8s (driven by NorMuon lr decrease, more granular + larger attention windows)
  • slightly avg larger window +X seconds? (from schedule update + no iteration_extension)
  • Muon updates -Y seconds?
  • Total = -3.22

The update to generate_standard_param_groups() means that mlp_up and mlp_down will not share a param group, so there's [11+5 padding] and [11+5 padding] instead of [22+2 padding] for the MLP. I believe this will slow down optimizer.step() for people who want to experiment with adding attn/mlp params. If this is intended to be just a cosmetic update, I don't think the slowdown is worth it. Maybe I'm missing another factor here.
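The padding numbers work out as follows (a small illustrative check, assuming each group is padded up to a multiple of the 8-GPU world size):

import math

world_size = 8
for n in (11, 22):
    padded = math.ceil(n / world_size) * world_size
    print(f"{n} params -> {padded} ({padded - n} padding)")
# 11 params -> 16 (5 padding)
# 22 params -> 24 (2 padding)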

Clean update to CastedLinear.reset_params! Gonna look at the Muon vectorization more later, looks advanced haha

@ClassicLarry
Collaborator

With respect to this change:

elif params[module_idx].label == "smear_gate":
    # dividing by magnitude is equivalent of SVN for 1d tensors
    v_chunk = updated_grads / (updated_grads.norm(dim=(-2, -1), keepdim=True).clamp_min(1e-10))

When I added this param to Muon I looked at some of these normalization options. Only doing normalization gives slightly different results than NS or polar express, as NS and polar express give a norm that bounces around 1 with some variance.

import torch

def zeropower_via_newtonschulz5(G, steps=5):
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X

updated_grads = torch.rand((10, 1))
out1 = updated_grads / (updated_grads.norm(dim=(-2, -1), keepdim=True).clamp_min(1e-10))
out2 = zeropower_via_newtonschulz5(updated_grads)
out3 = polar_express(updated_grads.to('cuda'))  # polar_express as defined in the training script
out1, out2, out3

At the time, the hyperparams had already been tuned around the full NS() call, so just doing norm() gave worse results, and the runtime was ~equivalent, as I suppose the matmuls on a [12, 1] were negligible. Conceptually, doing a basic norm on a 1d param is way simpler, so I think it's a good change, but if this wasn't directly ablated I think I'll revisit it at some point. (Maybe NorMuon made it so the perturbations in step size induced by the orthogonalization step are no longer beneficial.)
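For what it's worth, reading the "SVN" in that code comment as SVD, the claim does check out for a single column: dividing an [n, 1] tensor by its norm is exactly the orthogonal (polar) factor from its SVD. A quick hedged check:

import torch

g = torch.randn(12, 1)
u, s, vh = torch.linalg.svd(g, full_matrices=False)
print(torch.allclose(u @ vh, g / g.norm(), atol=1e-6))  # True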

@varunneal
Contributor Author

varunneal commented Oct 28, 2025

@ClassicLarry

This is making me wonder if the substantial drop here is tied to NorMuon not having an optimal lr of 0.06.

Definitely could be part of it. The net effect of the change here is that we halved the LR on Q, K, V, O, MLP down, attn gate, and smear gate while keeping it the same for MLP up, as you pointed out, so the average LR is now much lower. However, I'm not sure "average LR" is the right indicator, since the region around the optimal LR is very nonlinear. Page 8 of this paper might be of interest.

The update to generate_standard_param_groups() means that mlp_up and mlp_down will not share a param group, so theres [11+5 padding] and [11+5 padding] instead of [22+2 padding] for MLP.

Separating MLP Up and Down was an artifact of testing out separate lr_muls for both of them. Now this is totally redundant, and I should have switched the behavior back.

CastedLinear.reset_params

Yeah, just a random cleanup change I should have mentioned. Since we always zero-init the CastedLinear anyway, I thought this would improve the legibility of the code.

# dividing by magnitude is equivalent of SVN for 1d tensors

Totally forgot I added this and didn't check the ablation. Now that you mention it, it might be a big deal. A new result shows that a closer approximation to a perfect SVD requires a slightly different learning rate. Since the previous Polar Express step on 1d tensors was actually a worse approximation than the new method, the learning rate changes might have a correlated impact.

Playing around with lr_mul on the gates might prove very effective. We should potentially scale down the lr on both by 10-20x.

@ClassicLarry
Collaborator

Additionally, I corrected a subtle bug where gradients were being summed in grad_accum_steps but averaged over ranks. In practice this is mostly irrelevant due to magnitude invariance, however it causes minor precision issues for <8 devices.

If I'm understanding correctly: this has no impact on the 8-GPU setup, but for people testing on 1 GPU the gradient was previously 8x larger than it should have been. During training we average grads across GPUs via dist.ReduceOp.AVG, but someone testing on 1 GPU with grad_accum_steps=8 ends up summing across all 8 due to 8 iterative calls of .backward(). However, since both AdamW and Muon effectively normalize the magnitude, the impact on the loss curve is negligible. Checking the data:

# old code on 1 GPU at step 1
model.lm_head.weight.grad.norm(2)  # 176128

# new code on 1 GPU at step 1
model.lm_head.weight.grad.norm(2)  # 21888, ~8x smaller norm

# single call to backward() with no /grad_accum_steps on 1 GPU to validate
model.lm_head.weight.grad.norm(2)  # 23040, ~8x smaller norm (what we would get on 8 GPUs)
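A toy version of the same comparison (illustrative only, with constant stand-in grads):

import torch

g = [torch.full((4,), 1.0) for _ in range(8)]  # stand-in per-microbatch grads
avg_over_ranks = torch.stack(g).mean(0)        # 8-GPU path: dist.ReduceOp.AVG
sum_over_accum = torch.stack(g).sum(0)         # old 1-GPU path: 8 backward() calls summing into .grad
print((sum_over_accum.norm() / avg_over_ranks.norm()).item())  # 8.0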

When I run this on 1 GPU in Google Colab I get an OOM because the model has to keep all 8 computational graphs alive:

loss = 0
for _ in range(grad_accum_steps):
    inputs, targets, cum_seqlens = next(train_loader)
    loss += model(inputs, targets, cum_seqlens, ws) / grad_accum_steps
loss.backward()  # all 8 graphs stay alive until this single backward

I believe we can get the same bug fix with:

for _ in range(grad_accum_steps):
    inputs, targets, cum_seqlens = next(train_loader)
    loss = model(inputs, targets, cum_seqlens, ws) / grad_accum_steps
    loss.backward()  # frees each microbatch's graph right away

@varunneal
Contributor Author

@ClassicLarry that's much better, good suggestion

@ClassicLarry
Collaborator

Looked at the Muon code; I noticed you fixed a bug where in the last PR NorMuon was getting applied to attention while it was in shape (768, 768*4) instead of the ideal (4, 768, 768). Curious about the impact here. This PR has enough going on that some ablations of the changes would probably help people learn more.

@varunneal
Contributor Author

@ClassicLarry Good catch. Since we are applying NorMuon to the reshaped attention parameters, the transformation gets applied columnwise instead of rowwise (the condition is $m \geq n \to$ columnwise).

@varunneal varunneal force-pushed the fix-lr branch 2 times, most recently from 7bb2bbb to 78a4bb5 Compare October 31, 2025 20:36
@ClassicLarry
Collaborator

ClassicLarry commented Nov 1, 2025

I am also getting 138.9s for this record (run times in ms: [138891, 138705, 139154, 139030, 139014]) when I run on nightly 2.10.0.dev20250926+cu126, which is what was used here and for the backout PR.

When I retime the prior record I get 140.6s ([140704, 140564, 140595, 140671]).

So I can merge, but this would be shown as a 1.7s improvement. Let me know if this works, or if you want to revisit something here. If there is still a discrepancy, logs will be helpful for isolating the gap. Looking at the individual changes, 1.7s is roughly the improvement I would expect.

I notice the window-size schedule (3, 5, 7, 9, 11, 13) is quite sensitive to the nightly version; the Oct 22nd and Oct 31st nightlies add 1-2s. So we will keep future records on the 0926 nightly until someone finds a better one to use.

@varunneal
Contributor Author

@ClassicLarry I'm still in the process of running ablations but here are some preliminary results on timing:

All of these timings are averaged over 5 or 6 runs:

Torch 2.9 Stable

  • 141.433s vs 139.3244s: 2.1086s difference
  • 142.6292s vs 140.8714s: 1.755s difference (different GPU instance)

Torch dev 09-26

  • 141.9850s vs 139.2808s: 2.705s difference

Torch dev 10-01

  • 142.116s vs 141.4672s: 0.6488s difference

I'm leaning toward reverting the schedule changes (for both LR and WS). From my ablation testing they seem to account for most of the variance in time.

@varunneal varunneal changed the title New WR: Muon improvements: faster step, corrected learning rates (-3.2 seconds) New WR: Muon improvements: faster step, corrected learning rates (-3+ seconds) Nov 6, 2025
@varunneal varunneal changed the title New WR: Muon improvements: faster step, corrected learning rates (-3+ seconds) New WR: Muon improvements: faster step, corrected learning rates (30 steps) Nov 6, 2025
@ClassicLarry ClassicLarry merged commit edf5cdb into KellerJordan:master Nov 6, 2025
@ClassicLarry
Collaborator

Merged at 138.8s using the 0926 nightly. Based on logs that were shared separately, the hiccups/stuttering 'fixed' in PR #140 were not completely eliminated on all hardware, which is where the timing discrepancy above was coming from. The Muon vectorization here gives better coverage of this (not by meaningfully reducing median time per step, but by reducing the likelihood of hiccups/stuttering around step 2000).

@varunneal
Contributor Author

@ClassicLarry

When I run this on 1 GPU in google colab I'm getting OOM because the model is having to stay aware of all 8 computational graphs:

loss = 0
for _ in range(grad_accum_steps):
    inputs, targets, cum_seqlens = next(train_loader)
    loss += model(inputs, targets, cum_seqlens, ws) / grad_accum_steps
loss.backward()

I believe we can get the same bug fix with:

for _ in range(grad_accum_steps):
    inputs, targets, cum_seqlens = next(train_loader)
    loss = model(inputs, targets, cum_seqlens, ws) / grad_accum_steps
    loss.backward()

It turns out these are not equivalent, and in fact the 1-GPU loss seems to be a bit higher than the 8-GPU loss for that reason. This is because SUM(g_i) / 8 (the 8-GPU path) is more precise for small gradient values than SUM(g_i / 8) when we are accumulating in BF16. In fact, I think the best change here might be to remove the grad_accum_steps division entirely; the overflow problem (grads too large) is much rarer than the underflow issue. Alternatively, we can divide the grads by grad_accum_steps after computing the backward, which is easy to implement in step_optimizers.
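A minimal sketch of that last alternative (illustrative only; it assumes the training-loop shape quoted above and scales the accumulated grads once before the optimizer step):

for _ in range(grad_accum_steps):
    inputs, targets, cum_seqlens = next(train_loader)
    model(inputs, targets, cum_seqlens, ws).backward()  # accumulate unscaled grads
for p in model.parameters():
    if p.grad is not None:
        p.grad.div_(grad_accum_steps)  # single division after accumulation, before step_optimizers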
