New WR: Muon improvements: faster step, corrected learning rates (30 steps) #146
ClassicLarry merged 2 commits into KellerJordan:master
Conversation
Wow, this is great! I'll get #144 validated and merged to give this one some space.

I believe that eff_lr_val was already functioning as a dummy param, since max(1, 1/4) was evaluating to 1, which is something I had looked into when adding the attn reshape. The MLP lr issue was introduced in #109 and at the time appeared to have minimal impact on loss. If I'm interpreting these numbers right, the 0.06 -> 0.03 update effectively cuts the attn LR in half and the MLP_out LR in half, while leaving the MLP_in LR constant. When I previously tested the Muon LR in #140 it seemed pretty stable around 0.06. This makes me wonder whether the substantial drop here is tied to NorMuon not having an optimal LR of 0.06.

The update of the ws schedule from (3, 7, 11) -> (3, 5, 7, 9, 11, 13) is big! I imagine that increases time per step slightly. If I break out the timing contributions:
The update to generate_standard_param_groups() means that mlp_up and mlp_down will no longer share a param group, so there's [11 + 5 padding] and [11 + 5 padding] instead of [22 + 2 padding] for MLP. I believe this will slow down optimizer.step() for people who want to experiment with adding attn/mlp params. If this is intended as a purely cosmetic update, I don't think the slowdown is worth it; maybe I'm missing another factor here.

Clean update to CastedLinear.reset_params! Gonna look at the Muon vectorization more later, looks advanced haha
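The padding arithmetic above can be sketched as follows (rounding each group up to a multiple of 8 is my assumption, based on the 8-GPU setup; the actual padding rule is not read from the diff):

```python
def padded_size(n_params: int, world_size: int = 8) -> int:
    """Round a param-group size up to a multiple of world_size (dummy params fill the gap)."""
    return -(-n_params // world_size) * world_size

# Separate mlp_up / mlp_down groups: two groups of 11 -> 16 slots each (10 dummies total).
# One merged MLP group: 22 -> 24 slots (only 2 dummies).
separate = padded_size(11) + padded_size(11)  # 32 slots for 22 real params
merged = padded_size(22)                      # 24 slots for 22 real params
```

Under that assumption, splitting the group adds 8 dummy slots of work per optimizer step.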
With respect to this change: when I added this param to Muon I looked at some of these normalization options. Only doing normalization gives slightly different results than NS or polar express, since NS and polar express give a norm that bounces around 1 with some variance. At the time, the hyperparams had already been tuned around the full NS() call, so just doing norm() gave worse results, and the runtime was ~equivalent, as I suppose the matmuls on a [12, 1] were negligible. Conceptually, doing a basic norm on a 1d param is way simpler, so I think it's a good change, but if this wasn't directly ablated I'll revisit it at some point. (Maybe NorMuon made it so the perturbations in the step size induced by the orthogonalization step are no longer beneficial.)
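To illustrate why orthogonalization on a 1d param reduces to a plain norm, here is a toy sketch of the scalar dynamics of a cubic Newton-Schulz iteration on a rank-1 [n, 1] matrix (the cubic polynomial is a simplification for illustration; the repo's actual NS / polar express coefficients differ, which is exactly the "bounces around 1" behavior described above):

```python
def ns_scalar_step(s: float) -> float:
    # On an [n, 1] matrix X = v * s with unit vector v, the cubic NS update
    # X <- 1.5 X - 0.5 X X^T X only moves the norm s: s <- 1.5 s - 0.5 s^3.
    return 1.5 * s - 0.5 * s ** 3

s = 0.3  # pre-normalized Frobenius norm, assumed in (0, 1]
for _ in range(8):
    s = ns_scalar_step(s)
# s converges to 1: on a 1d param, the whole iteration just normalizes the
# vector, which a direct x / x.norm() achieves in one shot.
```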
Definitely could be part of it. The net effect of the change here is we halved the LR on Q, K, V, O, MLP down, attn gate, and smear gate while keeping the same for MLP Up, as you pointed out, so the average LR is now much lower. However, I'm not sure "average LR" is the right indicator, since the region around optimal LR is very nonlinear. Page 8 of this paper might be of interest.
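A toy check of this net effect (the `lr_mul` values here are my assumption for illustration: base LR 0.06 -> 0.03, with a `lr_mul` of 2.0 added on MLP up):

```python
old_base, new_base = 0.06, 0.03

# Effective LR = base LR * lr_mul (illustrative groups, not the full param list).
old_lr = {"qkvo": old_base * 1.0, "mlp_up": old_base * 1.0, "mlp_down": old_base * 1.0}
new_lr = {"qkvo": new_base * 1.0, "mlp_up": new_base * 2.0, "mlp_down": new_base * 1.0}

# qkvo and mlp_down are halved; mlp_up is unchanged, so the average drops
# even though no single group moved by more than a factor of 2.
avg_old = sum(old_lr.values()) / len(old_lr)
avg_new = sum(new_lr.values()) / len(new_lr)
```

As noted, though, the average may not be the right summary statistic given how nonlinear the loss is around the optimal LR.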
Separating MLP Up and Down was an artifact of testing out separate lr_muls for both of them. Now this is totally redundant, and I should have switched the behavior back.
Yeah, just a random cleanup change I should have mentioned. Since we always zero-init the CastedLinear anyway, I thought this would improve the legibility of the code.
Totally forgot I added this and didn't check the ablation. Now that you mention it, it might be a big deal. A new result shows that a closer approximation to a perfect SVD requires a slightly different learning rate. Since the previous polar express step on a 1d param is actually a worse approximation than the new method, the learning rate changes might have a correlated impact. Playing around with lr_mul on the gates might prove very effective. We should potentially scale down the LR on both by 10-20x.
If I'm understanding correctly, this has no impact on the 8 GPU setup, but for people testing on 1 GPU the gradient was previously 8x larger than it should have been. During training we average grads across GPUs via dist.ReduceOp.AVG, but if someone is testing on 1 GPU with grad_accum_steps=8, they will end up summing across all 8 due to the 8 iterative calls of .backward(). However, since both AdamW and Muon effectively normalize the magnitude, the impact on the loss curve is negligible.

Checking the data: when I run this on 1 GPU in Google Colab I'm getting OOM because the model has to stay aware of all 8 computational graphs. I believe we can get the same bug fix with:
@ClassicLarry that's much better, good suggestion |
Looking at the Muon code, I noticed you fixed a bug where in the last PR NorMuon was getting applied to attention in shape (768, 768*4) instead of the ideal (4, 768, 768). Curious about the impact here. This PR has enough going on that some ablations of the changes would probably help people learn more.
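A minimal sketch of the reshape in question (toy dim = 4 instead of 768; the assumption that Q, K, V, O are fused along the column dimension is mine, taken from the shapes quoted above rather than from the repo):

```python
import torch

dim = 4
# Fused attention weight with the Q, K, V, O blocks stacked along columns.
w = torch.randn(dim, 4 * dim)

# (dim, 4*dim) -> (4, dim, dim): each block becomes its own square matrix,
# so a per-row / per-column second-moment normalization now sees one
# projection at a time instead of one 4x-wide slab.
w4 = torch.stack(torch.chunk(w, 4, dim=1))
```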
|
@ClassicLarry Good catch. Since we are applying NorMuon to the reshaped attention parameters, the transformation gets applied columnwise instead of rowwise (the condition is |
I am also getting 138.9s for this record ([138891, 138705, 139154, 139030, 139014]) when I run on nightly 2.10.0.dev20250926+cu126, which is what was used here and for the backout PR. When I retime the prior record I get 140.6s ([140704, 140564, 140595, 140671]). So I can merge, but this would be shown as a 1.7s improvement. Let me know if this works, or if you want to revisit something here. If there is still a discrepancy, logs will be helpful to isolate the gap.

Looking at the individual changes, 1.7s is roughly what I would expect the improvement to be. I notice the window size schedule (3, 5, 7, 9, 11, 13) is quite sensitive to the nightly version; the Oct 22nd and Oct 31st nightlies add 1-2s. So we will keep future records on nightly 0926 until someone finds a better nightly to use.
@ClassicLarry I'm still in the process of running ablations, but here are some preliminary results on timing. All of these timings are averaged over 5 or 6 runs:

Torch 2.9 Stable
Torch dev 09-26
Torch dev 10-01
I'm leaning toward reverting the schedule changes (for both LR and WS). From my ablation testing, they seem to account for most of the variance in time.
Merged at 138.8s using the 0926 nightly. Based on logs that were shared separately, the hiccups/stuttering that was 'fixed' in PR #140 was not completely eliminated on all hardware, which is where the timing discrepancy above was coming from. The Muon vectorization here helps give better coverage on this (not by meaningfully reducing median time per step, but by reducing the likelihood of hiccups/stuttering around step 2000).
It turns out these are not equivalent, and in fact the 1-GPU loss seems to be a bit higher than the 8-GPU loss for that reason. This is because |
Faster Muon step, corrected learning rates
This record improves the step time of Muon and halves the learning rate.
Timing and Validation
This record reduces training by 30 steps and decreases time per step by around 1%.
This PR:
Previous PR (timed on same machine):
In total, this corresponds to a $3.30$ second decrease in training time. Roughly half of this is from the decreased number of steps, and the other half should be from the increased Muon efficiency. I expect that the timing improvements from the Muon vectorization will vary moderately by machine.
Thank you to Prime Intellect for sponsoring my research.
Changes
(1) LR Adjustment
I found that the Muon learning rate was ~twice as high as it should be, so I've decreased it to `0.03`. The lower LR may be in part due to the second-order impacts of NorMuon. Following the theory that effective learning rate is proportional to `sqrt(output_dim)`, I have increased `lr_mul` on the MLP up-projection to `2.0`. I have removed the logic that requires all parameters in the same group to share the same learning rate and weight decay.

(2) Muon step update
Vectorization
I vectorized several loops inside the Muon `step`, which slightly decreases step time. I am guessing we can apply `torch.compile` to a subpart of `step` for further gains as well. I moved the momentum buffers to being properties of groups, not of states, though this requires that we add a `reset()` (similar to Yarn).

Moved attention reshape
Moving the attention parameter reshape (from `[dim, 4 * dim]` -> `[4, dim, dim]`) to an earlier stage ensures that NorMuon gets applied columnwise to each parameter instead of rowwise. Empirical testing seems to indicate that NorMuon is more effective on the output dim (columnwise) than the input dim (rowwise).

(3) Corrections
As noted here, the current logic for `get_lr` does not flatten out during the iteration extension. I've corrected this issue, as well as a similar issue in `get_ws`.

Additionally, I corrected a subtle bug where gradients were being summed over `grad_accum_steps` but averaged over ranks. In practice this is mostly irrelevant due to magnitude invariance; however, it causes minor precision issues for $<8$ devices.
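The sum-vs-average mismatch can be checked with toy arithmetic (a sketch of the bookkeeping only, not of the actual training loop):

```python
grad_accum_steps = 8
micro_grads = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]  # toy per-micro-batch grads

# Each .backward() call adds into .grad, so without correction a single rank
# accumulates the SUM over its micro-batches ...
summed = sum(micro_grads)
# ... while dist.all_reduce with ReduceOp.AVG across 8 ranks yields the MEAN.
# Scaling each micro-batch loss by 1 / grad_accum_steps makes the two agree:
averaged = sum(g / grad_accum_steps for g in micro_grads)
```

Since normalized optimizers mostly cancel the overall scale, the residual effect is the minor precision difference noted above rather than a different loss curve.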