New WR: NorMuon variance normalization fix and GPU optimizations #168
Conversation
This script has my submission prior to incorporating pre-multiplication with W^O
Nice work, just a comment on the DistAdam changes:
Linking my comment #166 (comment) just in case you didn't see it, but I believe these reduce-scatters come from NorMuon (L529), which syncs every step; they are not from DistAdam. This might also affect your ordering changes, but I think the reordering shouldn't matter, because the reduce-scatters and all-gathers happen on the same communication stream.
The comms are slower than the optimizer step, so even if the optimizer step runs sooner, the comms are enqueued in the GPU stream waiting for the previous comms to complete anyway. As long as there is no gap in comms, I think there shouldn't be any difference. Maybe after you merge the optimizers it could help? Though the comms might still be the bottleneck. I tried using separate concurrent communication streams for the reduce-scatters and all-gathers, but I found that made things slower. These screenshots are from the trace from my run, so things might have changed; maybe there is some performance to be gained here, but I'd be interested to see the DistAdam changes independently to see the impact.
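To illustrate the queuing behavior, here's a minimal sketch (launched with torchrun; assumed shapes and names, not the actual optimizer code) of back-to-back reduce-scatters that serialize behind one another:

```python
# Minimal sketch, not the repo's DistAdam/NorMuon code: collectives issued in
# sequence queue up behind one another on the communication stream, so moving
# the optimizer step earlier on the compute stream doesn't make the final
# collective finish any sooner. Launch with: torchrun --nproc_per_node=8 demo.py
import torch
import torch.distributed as dist

def sync_grad_shards(grads: list[torch.Tensor]) -> list:
    world_size = dist.get_world_size()
    handles = []
    for g in grads:
        shard = torch.empty(g.numel() // world_size, device=g.device, dtype=g.dtype)
        # Each reduce-scatter is enqueued behind the previous one; as long as
        # there is no gap in comms, reordering the surrounding compute
        # changes nothing.
        handles.append(dist.reduce_scatter_tensor(shard, g, async_op=True))
    return handles

if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    grads = [torch.randn(1 << 20, device="cuda") for _ in range(4)]
    for handle in sync_grad_shards(grads):
        handle.wait()
    dist.destroy_process_group()
```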
This one has a lot of awesome stuff going on and things to learn from, and is easily a new record. I'm trying to piece together how much each part is contributing. First I checked removing NorMuon for everything except the MLP, and found roughly no change in behavior: got losses [3.2781, 3.2801, 3.2810, 3.2758, 3.2766, 3.2770] = 3.2781 and the same runtime. It seems NorMuon is only effective for rectangular matrices when applied to the large dimension, as Muon covers the small dimension. Dropping NorMuon from attention doesn't help on runtime because other GPUs are busy doing NorMuon on the MLPs. It honestly surprised me that there wasn't a runtime improvement. I think there should be a custom workload schedule that generates a runtime improvement here, but it would take a lot of thinking about which GPU should do what. The attention lambdas are now creating a different initialization behavior. Previously v and ve were initialized to 0.5 each.
A few smaller notes: on this line in polar_express, I am not following this part. There is a typo in the original code I wrote. The way you set up the cpu tensor is very interesting, same with the split_baddbmm. Optimization maxing, haha.
I think the intention was to initialize as 1.0, 0.5, which would start equivalently. Now it's 1 * (0.5 * v + ve); if we flip them we get 0.5 * (1 * v + ve), which is what used to be the case.
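A quick numeric check of that flip (v and ve here are stand-in tensors, not the actual model's):

```python
# Stand-in check: the current (0.5, 1.0) init is not equivalent to the old
# 0.5 * (v + ve), but flipping the two lambdas to (1.0, 0.5) recovers it.
import torch

v, ve = torch.randn(8), torch.randn(8)
old = 0.5 * (1.0 * v + ve)      # previous behavior at init
current = 1.0 * (0.5 * v + ve)  # new init: inner lambda 0.5, outer 1.0
flipped = 0.5 * (1.0 * v + ve)  # flipped init: inner 1.0, outer 0.5
print(torch.allclose(old, current))  # False: differs by 0.5 * ve
print(torch.allclose(old, flipped))  # True
```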
The new weighting scheme for the attention lambdas seems neutral or slightly better. The lr changes (2x attention, 4x mlp_proj) seem surprisingly neutral, maybe slightly negative. Applying NorMuon to gates and attention is maybe slightly worse on runtime and slightly better on loss; overall a tossup, since not doing NorMuon on attn saves us zero runtime while we are waiting on the MLP GPUs. I'm getting 130.200s on the final code. Will wait to merge for a couple days to give the latest record some time. I am quite surprised that the model is so insensitive to the Muon lr. I would have thought that cutting the step size in half or doubling it would cause a bigger impact on learning. I wonder if the model is just scaling its weights up to twice the values when the lr doubles, to mimic the same behavior.
@akash5474 - Thank you for pointing that out, I did miss your comment. I went back and looked more carefully at my trace file for the 11-29 baseline and I think you're right! I was trying to infer what was what based on the duration of the transfer, but looking at the alignment of the transfers and the compute makes it more clear. Here's what I'm thinking now. First, here are the parameters in the model and their sizes and data types: For the NorMuon-only step, the alignment of the transfers and the work makes it clear and confirms your thinking:
I assumed the smear gate transfer must be something bigger because of the long duration. Removing the torch.compile resolved this for me:
Maybe the lengthy transfer has more to do with the kernels not actually being warmed up? Maybe removing torch.compile helped resolve that bug to some degree, and gave me the above speedup (because these traces are from before the 12-10 record)? That would explain why the change isn't adding any benefit now to my submission vs. the 12-10 record. Separately, just as informational: to dig in further, I looked a little harder at attributing the transfers on the Adam+Muon step. I think the below is a safe assumption:
I also noticed this time around that, although we're grouping the parameters by size, it doesn't ultimately matter, because they are sent individually by the backward hooks. I asked Gemini to try and puzzle out what the sequence might be, and I think its rationale is strong:
It also pointed out that, if anything, our grouping is harmful and currently suboptimal (which we can confirm from how compute doesn't start until after the first three embedding tables are received), except of course that it doesn't matter because we're communication-bound regardless.
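A sketch of the mechanism being described (assumed structure and names, not DistAdam itself): one reduce-scatter is launched per parameter from its own backward hook, so any up-front grouping never changes what goes on the wire.

```python
# Sketch with assumed names, not DistAdam itself: each parameter's gradient is
# reduce-scattered from its own hook as soon as it's accumulated, in backward
# order, regardless of how the parameters were grouped beforehand.
import torch
import torch.distributed as dist

def attach_grad_sync_hooks(model: torch.nn.Module) -> list:
    world_size = dist.get_world_size()
    handles = []

    def hook(param: torch.nn.Parameter) -> None:
        # Fires once this parameter's grad is fully accumulated in backward.
        grad = param.grad.flatten()
        shard = torch.empty(grad.numel() // world_size,
                            device=grad.device, dtype=grad.dtype)
        handles.append(dist.reduce_scatter_tensor(shard, grad, async_op=True))

    for p in model.parameters():
        p.register_post_accumulate_grad_hook(hook)
    return handles
```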
Really insightful charts @chrisjmccormick. I would love to read a blog/guide you write on profiling, similar to @akash5474's guide.
Thanks for taking another look and for the detailed explanation @chrisjmccormick! Your images are so nice to look at, I would love to see them in a blog post 😉. It might be a good idea to remove the grouping, then. There might be a better grouping/order, but fwiw, when I worked on this, relying on the backward hook order resulted in a higher loss. I assumed the loss increase was caused by the backward hook ordering messing up the grouping by shape, so instead I used the reverse order to maintain the grouping (sketched below), and that brought the loss back down. I didn't dig into why the loss increased, but it could be worth exploring. I took a look at the trace in your PR and it appears the optimizer step starts later (after the 3rd reduce-scatter vs. after the 2nd), though it looks like it might finish slightly earlier, which reduces the overall compute time of the DistAdam.step.
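As a rough sketch of that reordering (assumed names, not the actual code): reversing the hook order approximately recovers construction order, which keeps same-shape parameters adjacent for grouped stepping.

```python
# Sketch, assumed names: backward hooks fire roughly output-to-input, which
# interleaves shapes; reversing recovers something close to construction
# order, so same-shape parameters stay grouped for the batched step.
from collections import OrderedDict
import torch

def group_params_by_shape(params_in_hook_order: list[torch.nn.Parameter]):
    groups: "OrderedDict[tuple, list]" = OrderedDict()
    for p in reversed(params_in_hook_order):  # reverse to maintain grouping
        groups.setdefault(tuple(p.shape), []).append(p)
    return groups
```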
If the parameters are being stepped in the same backward hook order as the gradients are being synced, in the trace we should see a more consistent pattern of compute beginning immediately after a reduce-scatter op completes. To illustrate, here's an example of what it looked like before the backward hook changes:
Thanks for the kind words @varunneal and @akash5474! I'll see what I can do re: a blog post.

@ClassicLarry Thank you so much for all of those experiments and the detailed investigation! Yeah, you've highlighted multiple mistakes on my part with what you found:

Number 1: Output Heads Didn't Change

You're right, I didn't actually fix the orientation of the output heads, they're still vertical! 🤦🏼♂️ The memory layout of the prior implementation was very tricky; I think I spent all of my brain power ensuring I was interpreting that one correctly, and not enough checking my "fix". I'm running some experiments now with

```python
y = y @ (sa_lambdas[1] * self.qkvo_w[self.dim * 3:].type_as(y))
```

I know I had it that way at one point; it must have gotten lost while shuffling around experiments.

Number 2: Learning Rate Multipliers

I completely glossed over that detail while making my changes / didn't bother to understand it.
So the MLP used to have a 2x multiplier on its output learning rate, and now it has a 4x on its input and still a 2x on its output. I'm going to look into that. I'm thinking the "double the learning rate for tall matrices" mechanism might not actually apply in the context of a Transformer FFN, because the functional units are the 768-dim vectors on both the input and the output side; it would make more sense to me in a standard MLP. I'm going to try removing it and/or applying 2x to both and will report back.
I'll just add slightly to the learning rate discussion that the expected Muon learning rate is proportional to sqrt(fan_in / fan_out). So the MLP up should have a learning rate 0.25x to 0.5x the MLP down lr (not 2x as much). But this is all theory, and practice may invalidate it.
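To make that concrete, here's the arithmetic assuming dim = 768 with a 4x MLP:

```python
# Worked example of the sqrt(fan_in / fan_out) heuristic for dim=768, 4x MLP.
import math

dim, hidden = 768, 4 * 768
up = math.sqrt(dim / hidden)    # MLP up, 768 -> 3072: 0.5
down = math.sqrt(hidden / dim)  # MLP down, 3072 -> 768: 2.0
print(up, down, up / down)      # 0.5 2.0 0.25 -> up lr is 0.25x the down lr
```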
@varunneal Interesting! I'll add that ratio to the list to try. |
The shape multiplier comes from param.shape, which is [4*dim, dim] for attn and hasn't been reshaped. But I'm not sure there is a performance diff either way.
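For reference, the multiplier that falls out of that shape:

```python
# The tall-matrix lr multiplier for an attn param of shape [4*dim, dim].
dim = 768
param_shape = (4 * dim, dim)
mul = max(1., param_shape[-2] / param_shape[-1]) ** 0.5
print(mul)  # 2.0
```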
Ok, here are some updates. I misspoke earlier about the mlp lr. The explicit 2x multiplier,

```python
self.c_fc.lr_mul = 2.
```

was on the input neurons, which had shape (768, 3072). Then in NorMuon, here is where we would increase the lr for tall matrices:

```python
max(1., param_shape[-2] / param_shape[-1]) ** 0.5
* ref_param.new_tensor(
    [getattr(param, "lr_mul", 1.0) for param in params[module_idx:module_idx + num_params]]
).view(-1, 1, 1)
```

This didn't apply to either of the mlp matrices, since they were both wide. So we had: 2x on the input, 1x on the output. With my changes:
Which I believe matches what you're suggesting, @varunneal? Half as high on the up projection. I tested out removing the explicit 2x multiplier on the output matrix, which results in the attention and mlp weights all having the same 2x multiplier on their learning rate. On the GH200, this actually provided a speed-up (maybe related somehow to the fact that having a 4x lr in the mix means an additional kernel vs. all of them being 2x?). I also tried removing these lr multipliers completely from attn and the mlps, which is the third row. No benefit.
Changing them all to 2x on the 8xH100 didn't have any benefit, though. (Which is kind of nice? No incentive here to go back and re-run it):
So it seems safe to leave as-is, and I'll add a note to my README about this.
Adding trace files for reference, which were created with the actual submitted code.
There were quite a few mistakes in my understanding of the original submission. I've gone through to remove them and incorporate our more recent insights.
Alright, I believe I've corrected all of the misconceptions in the README, code comments, and PR description. |








NorMuon & DistAdam Optimizations and Fixes
This update includes the two sets of changes below, and the submission now pre-multiplies `sa_lambdas[1]` with the W^O weights.

1. NorMuon Variance Normalization Fix
The existing variance normalization logic was incorrectly identifying the orientation of the smear/attention gates and output heads. I corrected this for the gates by adding a special case to the shape checking logic.
For the output heads, I changed the weight layout to (3072, 768), with QKVO stacked vertically. This allows all heads to be stored horizontally so that the variance normalization could be applied in a consistent direction.
However, changing the output head orientation provided no improvement in validation loss, and was slightly slower on a GH200, so they are currently still vertical.
I also transposed the MLP weights to maintain their matching shape.
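A minimal sketch of what a "consistent direction" means here (assumed shapes and a simplified normalizer, not NorMuon's exact code):

```python
# Simplified sketch, not NorMuon's exact implementation: the second-moment
# normalization runs along the neuron dimension, so the right reduction axis
# depends on whether neurons are stored as rows or as columns.
import torch

def variance_normalize(update: torch.Tensor, neurons_as_rows: bool,
                       eps: float = 1e-8) -> torch.Tensor:
    dim = 1 if neurons_as_rows else 0
    rms = update.pow(2).mean(dim=dim, keepdim=True).sqrt()
    return update / (rms + eps)

u = torch.randn(3072, 768)
u_rows = variance_normalize(u, neurons_as_rows=True)   # heads stored horizontally
u_cols = variance_normalize(u, neurons_as_rows=False)  # vertical QKVO stack
```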
A small experiment showed all variants below the baseline on loss:
(Each result is the average of four runs)
Below is the validation loss over time, averaged over the runs submitted in this PR. Step count is unchanged.
2. GPU & Algebraic Optimizations
I implemented compiled helpers for variance normalization and cautious weight decay, and partially fixed an unnecessary memcpy in polar_express. The illustration below captures the overall improvement to the NorMuon step time.
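As a rough sketch of the compiled-helper pattern (the masking rule below, decaying only where the weight and update agree in sign, is my assumption; the repo's exact formula may differ):

```python
# Hedged sketch of the pattern, not the repo's exact code: wrap the small
# pointwise optimizer math in torch.compile so it fuses into one kernel
# instead of launching many tiny elementwise ops.
import torch

@torch.compile
def cautious_weight_decay_(p: torch.Tensor, update: torch.Tensor,
                           lr: float, wd: float) -> None:
    # Assumed "cautious" rule: decay only the coordinates where the weight
    # and the update agree in sign, then apply the update itself.
    mask = (update * p) > 0
    p.mul_(torch.where(mask, 1.0 - lr * wd, 1.0))
    p.add_(update, alpha=-lr)
```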
I've included trace files for reference.
See the README for a thorough write-up on all of the above.
Edit: Updated to address what we've learned.