
[New Record] Retie Embedding and LM Head #175

Merged
ClassicLarry merged 6 commits into KellerJordan:master from varunneal:retie_lm_head on Dec 22, 2025

Conversation

@varunneal (Contributor) commented Dec 20, 2025

This record re-ties the LM Head to the first embedding and tunes some of the FP8 scales. I have incorporated the results of PR#172.

Timing and Validation

This record has 55 fewer steps than the previous record at approximately the same step time.

import scipy.stats
import torch

losses = [3.2777, 3.2790, 3.2776, 3.2792, 3.2760, 3.2792, 3.2767, 3.2763, 3.2770]
times = [123.853, 123.969, 123.929, 123.933, 123.914, 123.906, 123.970, 123.914, 123.964]

print("p=%.4f" % scipy.stats.ttest_1samp(losses, 3.28, alternative="less").pvalue)
# p=0.0002

print("losses:", torch.std_mean(torch.tensor(losses)))
# losses: (std=0.0013, mean=3.2776)

print("time:", torch.std_mean(torch.tensor(times)))
# time: (std=0.0375, mean=123.9280)

Previous record (timed on same machine):

import scipy.stats
import torch

times = [127.051, 127.139, 127.049, 127.147, 127.163, 127.161]

print("time:", torch.std_mean(torch.tensor(times)))
# time: (std=0.0537, mean=127.1183)

These timings show an improvement of $\approx 3.2$ seconds.
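Concretely, from the mean times above: $127.118 - 123.928 \approx 3.19$ seconds per run.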

Thank you to Prime Intellect for sponsoring my research with GPU credits.

Tied Embed and LM Head

Record 8 untied the LM Head from the Embed layer. That record came with two other changes: adding an RMS Norm after the embedding, and initializing the LM Head with zeros.

In this record, I reverse the first and third of those changes (the untying and the zero init). My motivation was to reduce the Adam step time by reducing the number of parameters that need to be communicated. I thought this tradeoff might be reasonable even if it came with an increase in the number of steps. In fact, I found that re-tying the Embed and LM Head significantly decreases the step count. On the other hand, there is basically no impact on the per-step time. I do not know exactly why, though I am guessing that the LM Head weight being both the first and last element in backprop breaks some of the current asynchronous logic in DistAdam.
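For reference, weight tying in PyTorch just points both modules at the same tensor, so the optimizer sees (and a distributed optimizer would communicate) a single parameter. A minimal sketch, with module and dimension names assumed rather than taken from this repo:

import torch
import torch.nn as nn

class TinyTiedLM(nn.Module):
    # Minimal sketch of tying the token embedding and the LM head (names assumed).
    def __init__(self, vocab_size: int, model_dim: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, model_dim)
        self.lm_head = nn.Linear(model_dim, vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight  # one shared nn.Parameter

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        x = self.embed(idx)
        # ... transformer blocks would go here ...
        return self.lm_head(x)

model = TinyTiedLM(vocab_size=50304, model_dim=768)
assert model.lm_head.weight is model.embed.weight  # only one tensor to step over and communicate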

Decreasing the FP8 Scales

I discovered the above prior to PR#172, but it turns out both benefit from retuning the FP8 scales. I ran the following experiment on the shared LM Head/Embed weight:

[plots from the FP8 scale sweep, not reproduced here]

which guided my choice of relative scales for the FP8 weights. These plots also indicate that a simple linear schedule could be optimal for the FP8 scales, though I failed to make this work with simple attempts.
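For illustration, this is the kind of linear schedule meant here; the helper and the endpoint values below are hypothetical, not something tested in this PR:

def fp8_scale_at(step: int, num_steps: int, start_scale: float, end_scale: float) -> float:
    # Hypothetical linear interpolation of an FP8 scale over training.
    frac = step / max(num_steps - 1, 1)
    return start_scale + frac * (end_scale - start_scale)

# Example with made-up endpoints: move x_s from 50/448 to 100/448 over 1000 steps.
print(fp8_scale_at(0, 1000, 50 / 448, 100 / 448))    # 0.1116...
print(fp8_scale_at(999, 1000, 50 / 448, 100 / 448))  # 0.2232...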

Applying the scales in this Record to PR#172 without tying the LM Head gave me the following results:

import scipy.stats
import torch

losses = [3.2754, 3.2756, 3.2779, 3.2746, 3.2740]

print("losses:", torch.std_mean(torch.tensor(losses)))
# losses: (std=0.0015, mean=3.2755)

It seems that the FP8 weight rescale improves the mean loss and reduces the high variance that was seemingly caused by cautious weight decay on the Adam parameters.

@ClassicLarry (Collaborator)

Interesting. This result is very unintuitive to me. Curious if it’s because CWD has been added to the lmhead (previously it was effectively off due to lr), or because there is something going on with fp8, or something else. Every other scenario I’ve seen of tying weights has increased loss.

@varunneal (Contributor, Author)

I actually got this result prior to CWD. I was also surprised that this is effective. Notably, I've removed the lr mul on the embed weight and set weight init very low (but nonzero). I'm guessing some previous attempts tried to use the very high weight init or learning rate from the embed for the lm head.

@ClassicLarry (Collaborator)

I actually got this result prior to CWD. I was also surprised that this is effective. Notably, I've removed the lr mul on the embed weight and set weight init very low (but nonzero). I'm guessing some previous attempts tried to use the very high weight init or learning rate from the embed for the lm head.

I am curious how things perform if the lm head gets untied for the last X steps. The more unintuitive a result is, the more there is to learn.

@ClassicLarry (Collaborator)

Will merge later at 124.5s, to maintain the 3.2s gap.

@ClassicLarry (Collaborator)

Overall I think this record is pretty significant and probably tees up another 2s.

I am getting losses around 3.275 when I let the embed untie from the lm_head for the last 1/3, and copy over the buffers. My shaky hypothesis here is that since the embed gradient is very sparse, it benefits from tagging along with the lm_head gradient early, since there are some core similarities between token relationships in lm_head space and embed space. This sparsity was previously handled by scaling up the lr 75x, which is less stable than the approach here. But later on in training the embed and lm_head are fundamentally different, and benefit from having their own representations. However this doesn't explain why the original change was made to untie the weights.

@varunneal (Contributor, Author) commented Dec 20, 2025

this doesn't explain why the original change was made to untie the weights.

I think it just showed a win at the time that was somewhat conflated with the other changes (norm after embed, init lm head to zeros). Imo 3.275 won't free up another 2s, probably 15 steps at most.

edit: Though if we can modify DistAdam to take advantage of the reduced parameter communication, I can imagine that giving another second.

@ClassicLarry (Collaborator) commented Dec 21, 2025

Imo 3.275 won't free up another 2s, probably 15 steps at most.

Agreed, am referring to the broader set of natural follow-ons.

-20 steps and 3.2778 mean loss [3.2806, 3.2797, 3.2755, 3.2771, 3.2772, 3.2775, 3.2768] from the simple change below:

def copy_lm_to_embed(self):
    # run at 2/3 of training
    lm_head = self.param_groups[0]['params'][0]
    embed = self.param_groups[-1]['params'][0]
    lm_head_state = self.state[lm_head]
    embed_state = self.state[embed]
    embed_state['step'] = lm_head_state['step']
    embed_state['exp_avg'] = lm_head_state['exp_avg'].clone()
    embed_state['exp_avg_sq'] = lm_head_state['exp_avg_sq'].clone()
    embed.data.copy_(lm_head.data)

# cautious weight decay. mask zeros.
mask = (update * p_slice) > 0
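For illustration only, here is one way such a mask could gate a decoupled weight-decay step. This is a guess at the surrounding code, not the repo's DistAdam; which side of the mask gets decayed and the update sign convention are assumptions:

import torch

def cautious_weight_decay_(p_slice: torch.Tensor, update: torch.Tensor,
                           lr: float, wd: float) -> None:
    # Hypothetical sketch: decay only the coordinates selected by the mask.
    # The strict ">" means coordinates where the update or the weight is exactly
    # zero are never decayed ("mask zeros").
    mask = (update * p_slice) > 0
    p_slice.mul_(1 - (lr * wd) * mask.to(p_slice.dtype))
    p_slice.add_(update, alpha=-lr)  # usual parameter update (sign convention assumed)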

And maybe someone can find a smart way to have the early training of the value embeds also guided by the lm_head, or tune the weight decay rate of value embeds separately.

@varunneal (Contributor, Author)

Nice, though I am guessing adding a new parameter adds some time right? There might be a lower rank way to separate the two params.

Theoretically, both the lm head and embed weights are converging to the same underlying weight, which is a stack of perfect vector representations of each token. I'm curious where the embedding weight diverges in practice. It may be to hedge for low representation tokens in an intelligent way

@ClassicLarry (Collaborator)

Nice, though I am guessing adding a new parameter adds some time right? There might be a lower rank way to separate the two params.

Theoretically, both the lm head and embed weights are converging to the same underlying weight, which is a stack of perfect vector representations of each token. I'm curious where the embedding weight diverges in practice. It may be to hedge for low representation tokens in an intelligent way

I haven't timed it with proper initial compile, though I'd expect a single param copy to be lightweight. Could be some weird stuff that happens with torch compile though. Don't want to say anything definitive yet on timing.

I would expect lm head and embed to have different values. E.g. the sequence A, B, C, A, B, C, A, B, C... will converge to an embed for A that matches the lm head for B, such that A predicts B. But the embed for A won't match the lm head for A. (Though there are some similarities between the distributions, such that synonyms should be close to each other in embed space and also close to each other in lm head space.)

ClassicLarry merged commit d8377c7 into KellerJordan:master on Dec 22, 2025
@ClassicLarry (Collaborator)

I went ahead and merged the record to unblock the next submissions, as the performance was previously validated. Could the logs please be added as well?

@varunneal (Contributor, Author)

@ClassicLarry Sorry about that. I think it should be on the new PR.

@ClassicLarry (Collaborator)

Am curious, on the fp8 scales, how much was empirical testing vs. following a formula.
self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=100/448, w_s=1.5/448, grad_s=0.75/448)

If I'm looking at the numbers correctly:

  • For x_s, 99.9% of data is kept below 1/4 of the 448 ceiling.
  • For w_s, 99.9% of data is kept below 1/2 of the 448 ceiling.
  • For grad_s, 99.9% of data is kept below 1/10 of the 448 ceiling.

I don't know any theory behind low precision handling.

@varunneal (Contributor, Author)

It was definitely more ad hoc than a specific formula, but it did help me figure out that x_s and w_s are different orders of magnitude (before, they were pretty similarly scaled). I did a 3D grid search, which can probably be improved upon.

Actually, the grad_s scale is really deceiving. It's computed in fp8 e5m2, which has a max value of 57,344, much higher than the e4m3 max value.


But the higher values for e5m2 are really sparse. In practice, e5m2 is useful for grads because we want to use the high-precision values near 0. Therefore we are happy to sacrifice utilization of the full range in order to treat more of the gradient values in high precision. That all said, I just kept the same denominator as the e4m3 values because that's what was there before and I found it useful. Changing the denominator to the real e5m2 max would be fine.
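For reference, these ranges can be checked directly against PyTorch's float8 dtypes (this check is mine, not from the PR):

import torch

print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
print(torch.finfo(torch.float8_e5m2).max)    # 57344.0
# e5m2 keeps only 2 mantissa bits, so values in its last binade are spaced 8192 apart,
# while values near zero are much more finely spaced -- which is why it suits gradients.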
