New Medium WR: Smear-MTP (-9.93 seconds) #151
ClassicLarry merged 4 commits into KellerJordan:master
Conversation
This one is interesting. I am hoping to do a validation run and merge of the medium track at the end of the month, but I am compute-limited right now.

Looks like over half the improvement is from faster time/step. If this holds up when controlling for hardware and torch version, the only explanation I can think of is that the new code happens to induce a better torch compilation path with more efficient kernels. I have seen this on H100, where adding computation gives faster time/step, so it seems plausible. A deep dive into which kernels are getting called might uncover additional opportunities.

I checked briefly on a much older version of the short track: the lambda first goes to -0.1 and then to -0.02, but I didn't see a clear improvement in loss. I'd be curious what the lambda value is over time on this run. One hypothesis, if the lambda goes negative, is that it provides a mechanism to remove attention-induced representation blurring, which is particularly interesting because directionally this concept would scale out to much bigger networks. Thinking from the perspective of previous-token attention heads across layers, where the sequence of tokens is [A, B, C], and about what B is predicting: in this scenario, the lambda gives the model more freedom to pull in context from prior positions, without getting biased towards accidentally predicting whatever those prior positions are themselves predicting.

A different angle (one I think is much less likely) is that words in fineweb (maybe) do not tend to repeat, so whatever is likely to immediately follow a position is unlikely to occur two positions out. If instead the lambda is positive here, I am at a loss for what the mechanism could be doing. I am cautious to put much weight on any hypothesis without clear data, because I think this one is easy to over-post-rationalize.
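Concretely, if the smear has roughly the form below (this functional form is my assumption from the description, not taken from the code), the sign of the lambda has a direct reading:

$$\text{logits}_t = W_{\text{head}}\left(h_t + \lambda\, h_{t-1}\right)$$

With $\lambda < 0$ the head subtracts a scaled copy of the previous position's latent, i.e. it cancels part of whatever the previous position was about to predict (the de-blurring reading); with $\lambda > 0$ it blends the previous position's prediction in (the usual MTP reading).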
Hmm. The functional form here is so restrictive that I wonder how much of the benefit of MTP is actually applying. The lambda is very close to zero, so the model is never really in a position to help out on the next token without messing up its own prediction. I don't want to read too much into a 15/5000-scale improvement. Given the scale of the lambda, the more interesting find to me is the (potential?) inefficiency in the prior torch compilation on 8xH100.
The lambda is pretty small, that's true. However, the improvements in loss per step do seem pretty consistent, especially with it also helping in the larger run. Unfortunately, I didn't record the lambda for the large run, so it's hard to confirm. I should definitely do more experiments to find out for sure, but I don't know if I have the time for that.
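For what it's worth, recording the lambda over time is cheap. A minimal sketch, assuming the scalar is exposed as `model.smear_lambda` (a hypothetical attribute name):

```python
# sketch: record the smear coefficient during training
# (`model` with a `smear_lambda` scalar parameter and the usual
#  training loop are assumed to exist in the surrounding code)
lambda_history = []
for step in range(num_steps):
    ...  # forward / backward / optimizer step as before
    if step % 100 == 0:
        # .item() copies the scalar to the host without touching the graph
        lambda_history.append((step, model.smear_lambda.item()))
```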

Perform Multi-Token Prediction (and adaptive compute) by smearing the previous token's output latents into the current token's output latents like this (compared to the baseline, simplified code):
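A minimal sketch of the idea (the parameter name `smear_lambda` and its placement relative to the final norm are illustrative assumptions; the change amounts to a single scalar parameter plus a sum and a concat):

```python
import torch
import torch.nn as nn

class SmearMTP(nn.Module):
    """Smear the previous token's output latents into the current token's."""

    def __init__(self):
        super().__init__()
        # the single extra parameter: a learned scalar mixing coefficient
        self.smear_lambda = nn.Parameter(torch.zeros(()))

    def forward(self, x):
        # x: (batch, seq_len, dim) output latents of the transformer stack.
        # The concat: shift right by one position, so position t sees the
        # latents of position t-1 (zeros at position 0).
        prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        # The sum: blend the previous position's latents into the current ones
        return x + self.smear_lambda * prev

# baseline:       logits = lm_head(norm(x))
# with Smear-MTP: logits = lm_head(norm(smear(x)))   where smear = SmearMTP()
```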
Statistics
Steps: reduced from 5550 to 5535.
Times: the previous record ran in 1384.6224 seconds (~23.08 minutes) → 9.9298 seconds reduction in wallclock time, i.e. the new run takes ~1374.69 seconds (~22.91 minutes).
Losses: reached a loss of <2.92 with confidence >>99%.
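One way such a confidence figure can be computed from per-run final losses is a one-sided t-test against the 2.92 target. A sketch with placeholder numbers (not the actual run data; the record's exact methodology may differ):

```python
import numpy as np
from scipy import stats

# placeholder per-run final validation losses (NOT the actual record data)
final_losses = np.array([2.916, 2.914, 2.918, 2.915])

# one-sided t-test of H1: mean final loss < 2.92
result = stats.ttest_1samp(final_losses, popmean=2.92, alternative="less")
print(f"confidence that mean loss < 2.92: {(1 - result.pvalue) * 100:.2f}%")
```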
Re-comparing to baseline
I ran initial tests to make sure that the record doesn't come from PyTorch updates or something like that. Here are the results over time:
The record seems to be legitimate. It's strange that the record run took less time than the baseline: the change is very cheap (it adds a single parameter and a sum + concat), but it still adds work, so it shouldn't be faster than the baseline; presumably this is random variation. To double-check, here are the losses over the training steps:
This still supports the record.
I also re-ran the old scaled-up run (a very simple transformer with ~1B parameters trained on ~10B tokens), once as-is and once with Smear-MTP added. The results also support the validity of the record: