
New WR: NorMuon variance normalization fix and GPU optimizations#168

Merged
ClassicLarry merged 20 commits into KellerJordan:master from chrisjmccormick:normuon-optims-and-fixes
Dec 14, 2025

Conversation

@chrisjmccormick
Contributor

@chrisjmccormick chrisjmccormick commented Dec 12, 2025

NorMuon & DistAdam Optimizations and Fixes

This update includes:

  1. Algebraic and fusion optimizations to NorMuon
  2. A fix to NorMuon's variance normalization step
  3. Pre-multiplication of sa_lambdas[1] with the $W^O$ projection (contributed by @shenberg)

Significance test against the 12-10 baseline:
import scipy.stats
import torch

# ======== 12-10 Baseline ========
accs = [3.2796, 3.2781, 3.2798, 3.2793]
times  = [132.5520, 132.8170, 132.7390, 132.7790]
p_value = scipy.stats.ttest_1samp(accs, 3.28, alternative='less').pvalue
print(f"p = {p_value:.4}")
print("Loss std, mean: ", torch.std_mean(torch.tensor(accs)))
print("Time std, mean: ", torch.std_mean(torch.tensor(times)))
# p = 0.06323
# Loss std, mean:  (tensor(0.0008), tensor(3.2792))
# Time std, mean:  (tensor(0.1176), tensor(132.7218))

# ======== This Record ========
accs = [3.2758, 3.2807, 3.2776, 3.2757, 3.2793, 3.2764, 3.2795, 3.2788, 3.2785, 3.2748]
times  = [131.2380, 131.2910, 131.1550, 131.1570, 131.2540, 131.3040, 131.1550, 131.1300, 131.1570, 131.2650]
p_value = scipy.stats.ttest_1samp(accs, 3.28, alternative='less').pvalue
print(f"p = {p_value:.4}")
print("Loss std, mean: ", torch.std_mean(torch.tensor(accs)))
print("Time std, mean: ", torch.std_mean(torch.tensor(times)))
# p = 0.002439
# Loss std, mean:  (tensor(0.0020), tensor(3.2777))
# Time std, mean:  (tensor(0.0660), tensor(131.2106))

1. NorMuon Variance Normalization Fix

The existing variance normalization logic was incorrectly identifying the orientation of the smear/attention gates and of the output heads. For the gates, I corrected this by adding a special case to the shape checking logic.
For the output heads, I changed the weight layout to (3,072, 768), with QKVO stacked vertically. This allows all heads to be stored horizontally, so that the variance normalization can be applied in a consistent direction.
However, changing the output head orientation provided no improvement in validation loss and was slightly slower on a GH200, so they are currently still vertical.

I also transposed the MLP weights to maintain their matching shape.
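For intuition, here is a minimal sketch of the orientation issue (a hypothetical helper, not the repo's code): the low-rank second-moment estimate keeps one statistic per output unit, so the reduction axis must follow the weight's layout. A shape check like the one described above picks the reduction dim; misreading the orientation normalizes across the wrong axis, which is the class of bug fixed here.

```python
import torch

def row_second_moment(update: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch: keep one mean-square statistic per output unit.
    # For a tall (out, in) matrix, reduce over the input dim; for a wide
    # matrix, reduce the other way. Misreading the orientation normalizes
    # across the wrong axis.
    dim = -1 if update.size(-2) >= update.size(-1) else -2
    return update.square().mean(dim=dim, keepdim=True)
```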

A small experiment showed all variants below the baseline on loss:

[Plot of final validation loss for variance fixes applied to gates and W_O]

(Each result is the average of four runs)

Below is the validation loss over time, averaged over the runs submitted in this PR. Step count is unchanged.

[image]

2. GPU & Algebraic Optimizations

I implemented compiled helpers for variance normalization and cautious weight decay, and partially fixed an unnecessary memcpy in polar_express.
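As a sketch of what a compiled helper buys here (assuming the usual row-wise normalization; normalize_by_variance is my name for illustration, not the repo's): wrapping the elementwise arithmetic in torch.compile lets the inductor backend fuse the add/rsqrt/mul into a single kernel instead of several launches.

```python
import torch

def normalize_by_variance(update: torch.Tensor, v_row: torch.Tensor,
                          eps: float = 1e-8) -> torch.Tensor:
    # Elementwise: scale each row of the update by 1 / sqrt(variance + eps).
    return update * torch.rsqrt(v_row + eps)

# Compiling fuses the three elementwise ops into one kernel on first call.
normalize_by_variance_fused = torch.compile(normalize_by_variance)
```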

The below illustration captures the overall improvement to the NorMuon step time.

[image]

I've included trace files for reference.

See the README for a thorough write-up on all of the above.

Edit: Updated to address what we've learned.

@akash5474
Contributor

akash5474 commented Dec 12, 2025

Nice work, just a comment on the DistAdam changes

> Fixed an issue where Adam weights were being scattered on non-Adam steps due to torch.compile interference.

Linking my comment #166 (comment) just in case you didn't see it, but I believe these reduce-scatters come from Normuon L529, which syncs every step, and they are not from DistAdam.

This might also affect your ordering changes, but I think the reordering shouldn't matter because the reduce-scatters and all-gathers happen on the same communication stream.

[image]

The comms are slower than the optimizer step, so even if the optimizer step runs sooner, the comms are enqueued in the GPU stream waiting for the previous comms to complete anyway. As long as there is no gap in comms, I think there shouldn't be any difference. Maybe after you merge the optimizers it could help? Though the comms might still be the bottleneck. I tried using separate concurrent communication streams for the reduce-scatters and all-gathers, but I found that made things slower.

These screenshots are from the trace of my run, so things might have changed. Maybe there is some performance to be gained here, but I'd be interested to see the DistAdam changes independently to gauge the impact.

@ClassicLarry
Collaborator

ClassicLarry commented Dec 12, 2025

This one has a lot of awesome stuff going on and things to learn from, and is easily a new record. I'm trying to piece together how much each part is contributing.

First I checked removing Normuon for everything except the MLP, and found roughly no change in behavior. Got losses [3.2781, 3.2801, 3.2810, 3.2758, 3.2766, 3.2770] = 3.2781 and the same runtime. It seems Normuon is only effective for rectangular matrices when applied to the large dimension, as Muon covers the small dimension. Dropping NorMuon from attention doesn't help because other GPUs are busy doing Normuon on MLPs. This honestly surprised me that there wasn't a runtime improvement. I think there should be a custom workload schedule that generates a runtime improvement here, but would take a lot of thinking about which GPU should do what.

The attention lambdas are now creating a different initialization behavior. Previously v and ve were initialized to 0.5 each with 0.5*v + 0.5*ve. Now it's 1 * (v + 0.5*ve), which means v starts with 67% relative weighting instead of 50%, and the total v projection starts at 50% higher magnitude.

The shape_mult = max(1.0, shape[-2] / shape[-1]) ** 0.5 and shape updates mean that MLPs and attn now have a 2x multiple on the lr. For each param:

  • attn. Effectively doubles the lr.
  • c_fc. self.c_fc.lr_mul is dropped from 2 to 1, so no net change to lr.
  • c_proj. self.c_proj.lr_mul is increased from 1 to 2, so 4x net change to lr.
    I feel like I must be missing something here because I would think this would impact loss, will check further.

This line in polar_express, if G.size(-2) > G.size(-1): X = X.mT, is why the MLP shapes were set to [small, big]. Otherwise, polar_express was performing a transpose that slowed down the optimizer. I am hoping to understand more about whether this is having an impact. Somehow your version is faster, which I don't understand yet.
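On why the transpose can now be free: .mT in PyTorch is a zero-copy view, so the cost, if any, only appears later when a downstream kernel needs contiguous memory. A quick check:

```python
import torch

G = torch.randn(3072, 768)
X = G.mT  # transpose as a view: no data movement

assert X.data_ptr() == G.data_ptr()  # same underlying storage
assert not X.is_contiguous()         # only the strides changed
```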

I am not following this part: Instead, I changed the layout such that all heads are stored horizontally each with shape (128, 768). When I look at the code:
y = F.linear(y, sa_lambdas[1] * self.qkvo_w[self.dim * 3:].type_as(y)). This op will take self.o_w of shape [768, 768] and transpose it for F.linear, such that y (shape = B, T, hdim) can be matmuled with self.o_w.T (shape = hdim, dim). Which means self.o_w has shape (dim, hdim) before the transpose. I don't see any trick in play that enables the shape of o_w to be (hdim, dim).

There is a typo in the original code I wrote:
y = F.linear(y, self.qkvo_w.view(4, self.hdim, self.dim)[3].type_as(y)). This should say (4, self.dim, self.hdim), but it ends up not mattering because both are 768. Maybe that is creating some confusion here. F.linear(x, y) applies x @ y.T.
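To make the F.linear convention above concrete, a quick check (shapes follow the comment; the tolerance is loose to allow for kernel-level reassociation differences):

```python
import torch
import torch.nn.functional as F

B, T, hdim, dim = 2, 4, 768, 768
y = torch.randn(B, T, hdim)
o_w = torch.randn(dim, hdim)  # (dim, hdim) before the implicit transpose

# F.linear(x, w) computes x @ w.T, so o_w is used as (hdim, dim) internally.
out = F.linear(y, o_w)
assert torch.allclose(out, y @ o_w.mT, atol=1e-4, rtol=1e-4)
```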

The way you set up the CPU tensor is very interesting, same with the split_baddbmm. Optimization maxing, haha.
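The split itself isn't shown in this thread, but for reference, torch.baddbmm (presumably the op underlying split_baddbmm) computes beta * input + alpha * (batch1 @ batch2), batched over the leading dim:

```python
import torch

inp = torch.randn(4, 2, 3)
b1 = torch.randn(4, 2, 5)
b2 = torch.randn(4, 5, 3)

# Default beta = alpha = 1: out = inp + b1 @ b2 for each batch element.
out = torch.baddbmm(inp, b1, b2)
assert torch.allclose(out, inp + b1 @ b2, atol=1e-5)
```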

@shenberg
Contributor

> The attention lambdas are now creating a different initialization behavior. Previously v and ve were initialized to 0.5 each with 0.5*v+0.5*ve. Now its 1 * (v + 0.5*ve), which means v starts with 67% relative weighting instead of 50%, and the total v projection starts at 50% higher magnitude.

I think the intention was to initialize as 1.0, 0.5, which would start equivalently (now it's 1 * (0.5 * v + ve); if we flip them we get 0.5 * (1 * v + ve), which is what used to be the case).
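The algebra above checks out numerically; a small sketch (variable names are mine):

```python
import torch

torch.manual_seed(0)
v, ve = torch.randn(8), torch.randn(8)

old = 0.5 * v + 0.5 * ve    # previous init: lambdas (0.5, 0.5)
new = 1.0 * (0.5 * v + ve)  # current init as described: (1.0, 0.5) applied this way
flip = 0.5 * (1.0 * v + ve) # flipping the two scalars recovers the old behavior

assert torch.allclose(flip, old)
```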

@ClassicLarry
Collaborator

ClassicLarry commented Dec 12, 2025

change                       time       time_diff   losses
baseline                     131682.5
adam                         131683     0.0005
cautious                     131390     -0.293
normuon update               131154     -0.236
normuon_mlp_only             131052     -0.102
split_baddbmm                130862     -0.19
update shape                 130810.5   -0.0515
rewrite linear               130732     -0.0785
wo_lambda                    130402.7   -0.32933    3.2797, 3.2753, 3.2782
normuon_all (not just mlp)   130471.7   0.069       3.2779, 3.2776, 3.2747
lr_change                    130403     -0.06867    3.2788, 3.2782, 3.2757
running final from source    130.2
I went through every change individually to understand where the impact was coming from. Adam updates seemed to show no impact, cautious was a clean 0.3s win, normuon compile was another clean 0.2s win, and the split_baddbmm was a clean 0.2s win. Updating the parameter shapes and simplifying the attention linear ops gave roughly a 0.1s improvement. The original shapes were picked back when Newton-Schulz was used and the transpose() op was incurring a heavy runtime penalty. It looks like this transpose now happens with zero runtime penalty, so the simpler shaping used here is great.

The new weighting scheme for the attention lambdas seems neutral or slightly better. The lr changes (2x attention, 4x mlp_proj) seem surprisingly neutral, maybe slightly negative. Applying Normuon to gates and attention is maybe slightly worse on runtime and slightly better on loss, overall a tossup, since skipping Normuon on attn saves us zero runtime while we are waiting on the MLP GPUs.

I'm getting 130.200s on the final code. Will wait to merge for a couple days to give the latest record some time. I am quite surprised that the model is so insensitive to the Muon lr. I would have thought that cutting it in half or doubling the step size would cause a bigger impact on learning. I wonder if the model is just scaling up its weights to twice the values when the lr doubles, to mimic the same behavior.

@chrisjmccormick
Contributor Author

chrisjmccormick commented Dec 12, 2025

@akash5474 - Thank you for pointing that out, I did miss your comment. I went back and looked more carefully at my trace file for the 11-29 baseline and I think you're right!

I was trying to infer what was what based on the duration of the transfer, but looking at the alignment of the transfers and the compute makes it more clear. Here's what I'm thinking now.

First, here are the parameters in the model and their sizes and data types:

Parameter Name                                    Dimensions       Total Values    Trainable    Type

scalars                                               62 x -                62     True         torch.float32
embed.weight                                      50,304 x 768           36.84M    True         torch.bfloat16
smear_gate.weight                                      1 x -                12     True         torch.bfloat16
value_embeds.0.weight                             50,304 x 768           36.84M    True         torch.bfloat16
value_embeds.1.weight                             50,304 x 768           36.84M    True         torch.bfloat16
value_embeds.2.weight                             50,304 x 768           36.84M    True         torch.bfloat16
blocks.1.attn.qkvo_w                                 768 x 3,072          2.25M    True         torch.float32
blocks.1.attn.attn_gate.weight                         6 x 12               72     True         torch.bfloat16
blocks.1.mlp.c_fc                                    768 x 3,072          2.25M    True         torch.float32
blocks.1.mlp.c_proj                                  768 x 3,072          2.25M    True         torch.float32
lm_head.weight                                    50,304 x 768           36.84M    True         torch.bfloat16

For the NorMuon-only step, the alignment of the transfers and the work makes it clear and confirms your thinking:

[image]

I assumed the smear gate transfer must be something bigger because of the long duration. Removing the torch.compile resolved this for me:

[image]

Maybe the lengthy transfer has more to do with the kernels not actually being warmed up? Maybe removing torch.compile helped resolve that bug to some degree, and gave me the above speedup (because these traces are from before the 12-10 record)? That would explain why the change isn't adding any benefit now to my submission vs. the 12-10 record.


This is just informational, not part of the above, but to dig in further, I looked a little harder at attributing the transfers on the Adam+Muon step. I think the below is a safe assumption:

[image]

I also noticed this time around that, although we're grouping the parameters by size, it doesn't ultimately matter, because they are sent individually by the backward hooks.

I asked Gemini to try and puzzle out what the sequence might be, and I think its rationale is strong:

Trace Order   Parameter Name          Shape          Usage Notes
1             lm_head.weight          50,304 x 768   Last op in forward, first in backward.
2             value_embeds.0.weight   50,304 x 768   Used only in Block 9 (of 12).
3             value_embeds.2.weight   50,304 x 768   Used in Block 11 and Block 2. Waits for Block 2.
4             value_embeds.1.weight   50,304 x 768   Used in Block 10 and Block 1. Waits for Block 1.
5             scalars                 62             Tiny size. Waits for end of backward pass.
6             embed.weight            50,304 x 768   Input embedding. Waits for end of backward pass.

It also pointed out that, if anything, our grouping is harmful and currently suboptimal (which we can confirm from how compute doesn't start until after the first three embedding tables are received), except of course that it doesn't matter, because we're communication-bound regardless.

@varunneal
Contributor

Really insightful charts, @chrisjmccormick! I would love to read a blog/guide you write on profiling, similar to @akash5474's guide.

@akash5474
Contributor

akash5474 commented Dec 12, 2025

Thanks for taking another look and for the detailed explanation @chrisjmccormick! Your images are so nice to look at, I would love to see them in a blog post 😉.

It might be a good idea to remove the torch.compile from _sync_gradient anyway so there's less confusion.

There might be a better grouping/order, but fwiw when I worked on this relying on the backward hook order resulted in a higher loss. I assumed the loss increase was caused by the backward hook ordering messing up the grouping by shape, so instead I used the reverse order to maintain the grouping and that brought the loss back down. I didn't dig into why the loss increased but it could be worth exploring.

I took a look at the trace in your PR and it appears the optimizer step starts later (after the 3rd reduce scatter vs after the 2nd), though it looks like it might finish slightly earlier which reduces the overall compute time of the DistAdam.step.

[image]

If the parameters are being stepped in the same backward hook order as the gradients are being synced, in the trace we should see a more consistent pattern of compute beginning immediately after a reduce-scatter op completes. To illustrate, here's an example of what it looked like before the backward hook changes:

[image: profiler-trace-adam-step-before-annotated]

@chrisjmccormick
Contributor Author

chrisjmccormick commented Dec 13, 2025

Thanks for the kind words @varunneal and @akash5474! I'll see what I can do re: a blog post. ☺️

@ClassicLarry Thank you so much for all of those experiments and the detailed investigation!

Yeah, you've highlighted multiple mistakes on my part with what you found:

Number 1: Output Heads Didn't Change

You're right, I didn't actually fix the orientation of the output heads; they're still vertical! 🤦🏼‍♂️ The memory layout of the prior implementation was very tricky; I think I spent all of my brain power ensuring I was interpreting that one correctly, and not enough checking my "fix".

I'm running some experiments now with

y = y @ (sa_lambdas[1] * self.qkvo_w[self.dim * 3:].type_as(y))

I know I had it that way at one point; it must have gotten lost while shuffling around experiments.
We'll see if it actually does anything. From your experiments removing the variance normalization, Larry, it sounds like it probably won't.

Number 2: Learning Rate Multipliers

I completely glossed over that detail while making my changes / didn't bother to understand it.
Here's what I'm seeing:

  • I think attention lr should be unchanged because this shape multiplier step comes after the attention weight updates have been reshaped into four squares (here).
  • The MLP implementation included a "corrective multiplier" here to account for the mlp output weights being in a different orientation than the shape_mult heuristic expects. I was oblivious to what that was about, and it's now on the input weight matrix instead.

So the MLP used to have a 2x multiplier on its output learning rate, and now it has a 4x on its input and still a 2x on its output.

I'm going to look into that. I'm thinking the "double the learning rate for tall matrices" mechanism might not actually apply in the context of a Transformer FFN, because the functional units are the 768-dim vectors on both the input and the output side. It would make more sense to me in a standard MLP. I'm going to try removing it and/or applying 2x to both and will report back.

@varunneal
Contributor

varunneal commented Dec 13, 2025

I'll just add slightly to the learning rate discussion that the expected Muon learning rate is proportional to sqrt(fan_in / fan_out). So the MLP up-projection should have a learning rate 0.25x - 0.5x that of the MLP down-projection (not 2 times as much). But this is all theory, and practice may invalidate it.
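Evaluating that heuristic on this model's dimensions (768 and 3,072; muon_lr_ratio is a hypothetical helper name, not from the repo):

```python
import math

def muon_lr_ratio(fan_in: int, fan_out: int) -> float:
    # The heuristic above: expected Muon lr proportional to sqrt(fan_in / fan_out).
    return math.sqrt(fan_in / fan_out)

dim, hdim = 768, 3072
up = muon_lr_ratio(dim, hdim)    # MLP up-projection: sqrt(768/3072) = 0.5
down = muon_lr_ratio(hdim, dim)  # MLP down-projection: sqrt(3072/768) = 2.0
```

Under this heuristic the up-projection lr is 0.25x the down-projection lr, the low end of the 0.25x - 0.5x range quoted above.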

@chrisjmccormick
Contributor Author

@varunneal Interesting! I'll add that ratio to the list to try.

@ClassicLarry
Collaborator

> I think attention lr should be unchanged because this shape multiplier step comes after the attention weight updates have been reshaped into four squares (here).

The shape multiplier comes from param.shape, which is [4*dim, dim] for attn and hasn't been reshaped. But I'm not sure there is a performance diff either way.

@chrisjmccormick
Contributor Author

chrisjmccormick commented Dec 13, 2025

Ok, here are some updates.

I misspoke earlier about the MLP lr.
Previously, we had the lr multiplier here:

self.c_fc.lr_mul = 2.

On the input neurons, which had shape (768, 3072).

Then in NorMuon, here we would increase the lr for tall matrices:

max(1., param_shape[-2] / param_shape[-1]) ** 0.5
* ref_param.new_tensor(
    [getattr(param, "lr_mul", 1.0) for param in params[module_idx:module_idx + num_params]]
).view(-1, 1, 1)

Which didn't apply to either of the MLP matrices, since they were both wide.

So we had: 2x on the input, 1x on the output.

With my changes:

  • The mlp matrices are 4x taller than wide, so this gives them both a 2x multiplier.
  • I moved the additional, explicit 2x multiplier to the output matrix, so now we have:
    • 2x on the input neurons, 4x on the output neurons.

Which I believe matches what you're suggesting, @varunneal? Half as high on $W^{up}$ (input neurons) vs. $W^{down}$ (output neurons)?
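To keep the before/after accounting straight, here it is expressed with the quoted heuristic (effective_lr_mul is a name I'm introducing for illustration, not the repo's):

```python
def effective_lr_mul(shape, lr_mul: float = 1.0) -> float:
    # Combined multiplier: the tall-matrix heuristic times the explicit lr_mul.
    return max(1.0, shape[-2] / shape[-1]) ** 0.5 * lr_mul

# Before: both MLP weights stored wide, (768, 3072); shape factor is 1.0.
fc_before = effective_lr_mul((768, 3072), lr_mul=2.0)    # input neurons: 2x
proj_before = effective_lr_mul((768, 3072), lr_mul=1.0)  # output neurons: 1x

# After: both stored tall, (3072, 768), with the explicit 2x moved to c_proj.
fc_after = effective_lr_mul((3072, 768), lr_mul=1.0)     # input neurons: 2x
proj_after = effective_lr_mul((3072, 768), lr_mul=2.0)   # output neurons: 4x
```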

I tested out removing the explicit 2x multiplier on the output matrix, which results in attention and mlp weights all having the same 2x multiplier to their learning rate.

On the GH200, this actually provided a speed-up (maybe related somehow to the fact that having a 4x lr in the mix means an additional kernel vs. if they're all 2x?).

I also tried removing these lr multipliers completely from attn and mlps, which is the third row. No benefit.

[image]

Changing them all to 2x on the 8xH100 didn't have any benefit, though. (Which is kind of nice? No incentive here to go back and re-run it):

[image]

So it seems safe to leave as-is, and I'll add a note to my README about this.

Adding trace files for reference, created with the actual submitted code.
There were quite a few mistakes in my understanding in the original submission. I've gone through to remove them and incorporate our more recent insights.
@chrisjmccormick
Copy link
Copy Markdown
Contributor Author

Alright, I believe I've corrected all of the misconceptions in the README, code comments, and PR description.
Thanks everyone again for the feedback and analysis!
