
New Medium WR: Smear-MTP (-9.93 seconds) #151

Merged
ClassicLarry merged 4 commits into KellerJordan:master from snimu:mtp on Dec 31, 2025

Conversation

@snimu (Contributor) commented Nov 3, 2025

Perform Multi-Token Prediction (and adaptive compute) by smearing the previous token's output latents into the current token's output latents, like this (simplified code, baseline shown for comparison):

# Baseline
x = embed(input_sequence)
for layer in layers:
    x = layer(x)  # Attention + MLP, with residual
x = rms_norm(x)
output_distribution = softmax(language_head(x))

# With Smear-MTP
x = embed(input_sequence)
for layer in layers:
    x = layer(x)  # Attention + MLP, with residual
x = rms_norm(x)
x = torch.cat(
    [
        x[:, :1],  # Nothing to smear into token one
        x[:, 1:] + lambda_smear * x[:, :-1],  # Smear previous latents into current ones
    ],
    dim=1,  # Concatenate along the sequence dimension
)
output_distribution = softmax(language_head(x))
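
For concreteness, here is a self-contained sketch of the smear step as a small module with a learnable scalar. The names (SmearMTP, lambda_smear) and the module structure are illustrative; the actual record code may organize this differently:

import torch
import torch.nn as nn

class SmearMTP(nn.Module):  # illustrative name, not necessarily the PR's
    def __init__(self, init: float = 0.0):
        super().__init__()
        # One learnable scalar controlling how much of the previous token's
        # final latent is added into the current token's final latent.
        self.lambda_smear = nn.Parameter(torch.tensor(init))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, dim) latents after the final rms_norm, before the language head
        return torch.cat(
            [
                x[:, :1],                                  # token 0 has no predecessor
                x[:, 1:] + self.lambda_smear * x[:, :-1],  # smear previous latents into current ones
            ],
            dim=1,  # concatenate along the sequence dimension
        )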

Statistics

Steps: Reduced from 5550 to 5535.

Times:

[1374.632, 1373.729, 1373.736, 1373.197, 1376.135, 1374.402, 1381.415, 1373.751, 1376.031, 1373.949, 1373.603, 1372.811, 1373.156, 1373.248, 1374.181, 1374.237, 1373.307, 1372.933, 1385.601, 1374.115, 1372.798, 1374.607, 1376.245, 1373.705, 1372.919, 1373.565]

Time Stats:

  • N: 26
  • Mean: 1374.6926 seconds (~22.91 minutes)
  • Std: 2.83 seconds

The previous record ran in 1384.6224 seconds (~23.08 minutes) → 9.9298 seconds reduction in wallclock time.

Losses:

[2.921089, 2.918772, 2.919393, 2.920201, 2.919555, 2.918825, 2.91906, 2.918602, 2.919583, 2.920038, 2.91947, 2.918655, 2.919565, 2.919967, 2.919333, 2.919548, 2.919505, 2.918065, 2.91992, 2.919905, 2.918552, 2.919111, 2.918989, 2.919336, 2.91804, 2.91896]

Loss Stats:

  • N: 26
  • Mean: 2.9193
  • Std: 0.000678
  • p-value: 1.1266e-05

→ Reached loss of <2.92 with confidence >> 99%.
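
For reference, here is how a p-value of this kind can be computed, assuming it comes from a one-sided one-sample t-test of the 26 final losses against the 2.92 target (my guess at the test used; the repo's own verification script may differ):

import numpy as np
from scipy import stats

losses = np.array([
    2.921089, 2.918772, 2.919393, 2.920201, 2.919555, 2.918825, 2.91906,
    2.918602, 2.919583, 2.920038, 2.91947, 2.918655, 2.919565, 2.919967,
    2.919333, 2.919548, 2.919505, 2.918065, 2.91992, 2.919905, 2.918552,
    2.919111, 2.918989, 2.919336, 2.91804, 2.91896,
])
# One-sided test: is the mean loss significantly below the 2.92 target?
t_stat, p_value = stats.ttest_1samp(losses, popmean=2.92, alternative="less")
print(p_value)  # on the order of 1e-05, i.e. well below 0.01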

Re-comparing to baseline

I ran initial tests to make sure that the record doesn't come from PyTorch updates or similar environment changes. Here are the results over time:

[figure: record vs. baseline results over time]

The record seems to be legitimate. It's strange that the record run took less time than the baseline: the change is very cheap (it adds a single parameter plus a sum and a concat), but it still shouldn't be faster than the baseline, so I attribute this to random variation. To double-check, here are the losses over the training steps:

[figure: loss over training steps, record vs. baseline]

This still supports the record.

I also took the old scaled-up setup (a very simple transformer with ~1B parameters trained on ~10B tokens), re-ran it as-is, and ran it again with Smear-MTP. The results also support the validity of the record:

[figure: ~1B-parameter scaled-up run, baseline vs. Smear-MTP]

@ClassicLarry (Collaborator) commented:

This one is interesting. I am hoping to do a validation run and merge of the medium track at the end of the month, but I am compute-limited right now.

Looks like over half the improvement is from a faster time/step. If this holds up when controlling for hardware and torch version, the only explanation I can think of is that the new code happens to induce a better torch compilation path with more efficient kernels. I have seen this on H100s, where adding computation gives a faster time/step, so it seems plausible. A deep dive into what kernels are getting called might uncover additional opportunities.
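
For that kind of deep dive, a generic torch.profiler sketch for listing which CUDA kernels a forward/backward pass dispatches could be a starting point (the toy Linear below is just a stand-in for the real training step, not this repo's code):

import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.Linear(1024, 1024).cuda()    # stand-in for the real training step
x = torch.randn(64, 1024, device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    model(x).sum().backward()

# Table of the most expensive CUDA kernels; comparing this between the baseline
# and the Smear-MTP variant would show whether different kernels get picked.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))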

I checked briefly on a much older version of the short track, and the lambda first goes to -0.1 and then to -0.02, but I didn't see a clear improvement in loss. I'd be curious what the lambda value is over time on this run. One hypothesis, if the lambda goes negative, is that the lambda provides a mechanism to remove attention-induced representation blurring, which is particularly interesting because directionally this concept would scale out to much bigger networks. Thinking from the perspective of previous-token attention heads across layers, where the sequence of tokens is [A, B, C], and considering what B is predicting:

  1. In earlier layers, attention pulls a representation of A into B. In an idealized world, that representation of A would sit in an isolated subspace of B, but in practice there is some amount of representation blurring. So when 'B' goes to predict a token, it may get biased towards whatever A would have predicted, because it's not clear on its own identity.
  2. In later layers, the representation of B has shifted to 'B_prediction' and the representation of A has shifted to 'A_prediction'. So while attention from B wants to pull in the current context of A, it is instead pulling in 'A_prediction'. The value embeds help ground this towards 'A', but only partially. As a result, the representation in B gets biased towards whatever A is predicting.

In this scenario, the lambda is giving the model more freedom to pull in context from prior positions, without having to worry about getting biased towards accidentally predicting whatever the prior positions are predicting.

A different angle (one I think is much less likely) is that words in fineweb may not tend to repeat, so whatever is likely to immediately follow a position is unlikely to occur two positions out.

If instead the lambda is positive here, I am at a loss for what the mechanism could be doing. I am cautious to put much weight into any hypothesis without clear data because I think this one is easy to over post-rationalize.

@snimu (Contributor, Author) commented Nov 4, 2025

Very interesting hypothesis about the lambdas.

I'll address the timing first: it's true that the per-step time is reduced, both when comparing the timings in the previous PR to the record attempt in this one, and in the direct comparison I've made above (which is yet another pair of runs). And yes, this is weird haha. I should note, though, that it only applies in the medium track: in the 1B-parameter run, the timing for the Smear-MTP run is slightly worse.

As for the hypothesis about what happens, here are the lambdas over the course of training for 26 different runs (the mean is in bold):

[figure: lambda over training steps for 26 runs, with the mean highlighted]

It very consistently first goes negative, then positive. That it ends positive seems to contradict your hypothesis (unfortunately!).
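
For anyone wanting to reproduce this kind of trace, here is a minimal toy sketch of logging a learnable smear coefficient over training steps (a self-contained stand-in, not the record's training loop; in the real run the lambda is learned inside the full model):

import torch

lambda_smear = torch.nn.Parameter(torch.tensor(0.0))  # toy version of the smear scalar
opt = torch.optim.Adam([lambda_smear], lr=1e-2)
trace = []
for step in range(200):
    x = torch.randn(8, 16, 32)               # fake (batch, seq, dim) latents
    target = x[:, 1:] + 0.05 * x[:, :-1]     # pretend the data prefers a small positive smear
    pred = x[:, 1:] + lambda_smear * x[:, :-1]
    loss = (pred - target).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    trace.append(lambda_smear.item())        # the kind of per-step trace plotted above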

My take is that it allows the model to do two things:

  • In general, it's multi-token prediction, which clearly has a positive effect on next-token prediction (see DeepSeek V3 for example).
  • During inference, the same MTP allows the model to choose how much of a token's representation power goes into predicting the next token vs. helping the prediction after that. For easy-to-predict tokens, it can push most of its effort into helping with the subsequent tokens, leading to better utilization of the model's compute. But for hard-to-predict tokens, it doesn't have to do that.

I'm of course very, very unsure about these hypotheses; but at least multi-token prediction helping training seems pretty plausible, as the model is clearly doing MTP when latents are smeared like that.

@ClassicLarry (Collaborator) commented:

Hmm. The functional form here is so restrictive that I wonder how much of the benefit of MTP is actually applying. The lambda is very close to zero, so the model is never really in a position to help out on the next token without messing up its own prediction. I don't want to read too much into a 15/5000-step improvement. Given the scale of the lambda, the more interesting find to me is the (potential?) inefficiency in the prior torch compilation on 8xH100.

@snimu (Contributor, Author) commented Nov 4, 2025

The lambda is pretty small, that's true. However, the improvement in loss per step does seem pretty consistent, especially with it also helping in the larger run. Unfortunately, I didn't record the lambda for the large run, so it's hard to confirm.

I should definitely do more experiments to find out for sure, but I don't know if I have the time for that.
