New Record: Multi-token prediction and Untie LM Head 2/3rds through training (119.76 seconds) #178
Conversation
Under 2 minutes, you did it!! 😃
Re: CWD, I was thinking of making the same change. My logic was that, for embedding params with sparse gradients, it undoes the 'cautious' part of CWD when the gradient == 0 and becomes regular weight decay for unused tokens.

MTP: It's a really cool method! Did you experiment with DeepSeek-V3-style MTP modules (or this slightly simpler approach)? I would have hoped for a method that's "always on" for MTP, though this method being parameter-free and really cheap to compute is very appealing. A baseline ablation, decaying label smoothing, maybe? Or, as a poor man's approximation, fixed secondary and tertiary token targets? (The whole idea seems reminiscent of online label smoothing.) Two questions I have are:
@shenberg I didn't see that paper, but I was definitely inspired by some of the work on MTP, especially the new MiniMax model. For (a), I essentially wanted to follow the tripartite batch size schedule, so I had some constraints on the search space. Using the shorthand notation for the search space shows there aren't really too many variables to optimize over: I had to choose the number of tokens in each phase (e.g. 2, 3, or 4) and then what weight each token should get. For the latter, the 1, 1/2, 1/4, ... was my first guess, following the ProphetNet intuition, though I ran some experiments to confirm it was better than some alternatives. Obviously there are a lot of ways to make an MTP schedule, but I like to narrow the search space as much as possible and see if I can get something working. Hopefully future optimizations can be found.

For (b), you raise a good point. We could try using a cheap mask. Maybe a future PR idea would be to mask out all train losses that follow the BOS token. Though maybe this is not useful, since learning the most common starting document n-gram is important at val. Perhaps only apply the mask for parts of training.
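A minimal sketch of what such a cheap BOS mask might look like (the function name, tensor shapes, and token id are my assumptions, not code from this repo):

```python
import torch
import torch.nn.functional as F

BOS_ID = 50256  # assumed BOS token id (GPT-2 tokenizer convention)

def bos_masked_loss(logits: torch.Tensor, inputs: torch.Tensor, targets: torch.Tensor):
    """Cross-entropy that drops positions whose input token is BOS,
    i.e. skips the loss for predicting the first token of each document."""
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none"
    )
    keep = (inputs.view(-1) != BOS_ID).float()
    return (loss * keep).sum() / keep.sum()
```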
Will merge in a couple of days at 119.3s (-2.9s)
By the way, I got the following ablations:
In previous tests I saw some benefit from
cc: @ClassicLarry
This record implements:

- Multi-token prediction
- Untying the LM head from the embedding weights 2/3 of the way through training
There are additionally the following minor changes:

- `>=` to `>` (contribution from @ClassicLarry)

Timing and Validation
This PR has 80 fewer steps than PR #177 at a slightly higher step time.
Previous record (timed on same machine):
These timings show $\approx 3$ seconds of improvement.
Thank you to Prime Intellect for sponsoring my research with GPU credits.
Multi-token prediction
This record implements multi-token prediction (MTP) without adding any additional parameters. It uses a weighted average of the cross-entropy loss over the next $k$ tokens as the total step loss. In line with the batch size schedule, MTP follows three phases:
- Phase 1: weights start at `[1, 0.5, 0.25]` and decay to `[1, 0.5, 0.0]`.
- Phase 2: weights start at `[1, 0.5]` and decay to `[1, 0.0]`.
- Phase 3: standard next-token prediction with weight `[1]`.

In shorthand, this schedule can be described as `[1, 0.5, 0.25->0] -> [1, 0.5->0] -> [1]`.

I experimented with various other ideas for multi-token prediction, including adding trainable weights to distinguish the $k$th token predictions. Ultimately, a parameter-free approach seems effective. Intuitively, this naive approach may work because early training is learning the $n$-gram distribution. For this task, MTP increases the signal, and can be considered a cheap proxy for increasing the batch size.
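As a rough illustration (not the PR's actual implementation; the names and shapes here are my own), the weighted MTP loss with linearly decaying per-token weights might look like this:

```python
import torch
import torch.nn.functional as F

def mtp_weights(start, end, frac):
    """Linearly interpolate per-token weights within a phase, e.g.
    start=[1, 0.5, 0.25], end=[1, 0.5, 0.0], frac in [0, 1]."""
    return [s + (e - s) * frac for s, e in zip(start, end)]

def mtp_loss(logits: torch.Tensor, targets: torch.Tensor, weights):
    """Weighted average of cross-entropy over the next k tokens.

    logits:  (T, V) predictions, one per position
    targets: (T,) next-token ids (already shifted by one)
    weights: weights[k] scales the loss for the token k+1 positions ahead
    """
    total, denom = 0.0, 0.0
    for k, w in enumerate(weights):
        if w == 0.0:
            continue
        n = logits.size(0) - k  # the last k positions have no target that far ahead
        total = total + w * F.cross_entropy(logits[:n], targets[k:k + n])
        denom += w
    return total / denom
```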
If there are additional ablations or experiments you'd like me to run, please let me know in this PR's discussion.
Untying the LM Head and Embed weight
This idea goes back to a comment by @ClassicLarry in PR #175:
I found that the 75x learning rate is probably due more to the large initial weight magnitude in the original version of the embedding weights than to the sparsity of the updates. I found that the untied embedding weights did not benefit from a learning rate multiplier, though it may be possible to include one, perhaps by using a schedule or lowering the magnitude of the weights.
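A minimal sketch of the untying step, assuming hypothetical `model.embed` / `model.lm_head` attributes and a standard PyTorch optimizer (the repo's actual code will differ):

```python
import torch.nn as nn

def untie_lm_head(model, optimizer, lr: float, weight_decay: float):
    """Split the tied LM head off the token embedding so the two matrices
    receive independent updates from this point onward."""
    # Before this call, model.lm_head.weight and model.embed.weight are
    # the same Parameter; cloning lets them diverge.
    model.lm_head.weight = nn.Parameter(model.embed.weight.detach().clone())
    # Register the new parameter with the optimizer. Per the note above,
    # no learning-rate multiplier is applied to the untied head.
    optimizer.add_param_group(
        {"params": [model.lm_head.weight], "lr": lr, "weight_decay": weight_decay}
    )

# e.g. call once when step == int(2 / 3 * total_steps)
```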
Additional Details
- `>=` to `>` shows minor improvement in Adam.

TODO: Include ablations for the above.
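For context, a sketch of the kind of comparison being changed; the exact mask condition is my assumption based on the CWD discussion above, not the repo's actual code:

```python
import torch

def cautious_weight_decay(param: torch.Tensor, grad: torch.Tensor, lr: float, wd: float):
    """Decay weights only where decay agrees with the gradient signal.

    With `>=`, entries whose gradient is exactly zero (e.g. embedding rows
    for tokens absent from the batch) pass the mask and receive plain
    weight decay; the strict `>` excludes them."""
    mask = (param * grad > 0).float()  # previously: param * grad >= 0
    param.data.add_(param * mask, alpha=-lr * wd)
```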