New Record: Bigram Hash Embedding (-5.6s, -165 steps) #201
ClassicLarry merged 1 commit into KellerJordan:master
Conversation
Did you run an ablation test on bigram vs the "Fix partial key offset" change?
Yeah, fixing the partial key offset dims was worth roughly 0.001 loss, or about half a second.
I did some validation because the amount of loss decrease recently has made me more cautious about data leakage.
Wow! 5% improvement at this point is big. Congrats. Having Bigram Hash Embedding yield such an improvement is exciting. Also, it's interesting that trigrams or multiple hashes didn't show additional improvement. Sounds to me like there's some interesting stuff to understand there.
I want to clarify that only my particular implementation of trigrams and multiple hashes didn't help. In an execution-driven field, negative results are close to meaningless. I expect there are likely extensions that will be beneficial with the right implementation nuances.
I wrote an initial sparse-comms implementation for the gradient reduce-scatter (param updates are dense at the moment). It works using the all_to_all_single strategy (I believe it's the Horovod embedding all-to-all strategy):
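For illustration only, here is a minimal sketch of that kind of routing, not the actual implementation from this comment; `grad` (the dense embedding-gradient tensor on each rank) and `vocab_per_rank` (the size of each rank's slice of the table) are assumed names. Each nonzero row is routed via `all_to_all_single` to the rank that owns it and summed on arrival:

```python
import torch
import torch.distributed as dist

def sparse_reduce_embedding_grad(grad: torch.Tensor, vocab_per_rank: int) -> torch.Tensor:
    """Route each nonzero gradient row to the rank owning that slice of the
    embedding table via all_to_all_single, then sum on arrival. Returns this
    rank's reduced shard of shape (vocab_per_rank, dim)."""
    world_size, rank = dist.get_world_size(), dist.get_rank()

    # Rows with any nonzero gradient; everything else never needs to be sent.
    nz = grad.abs().sum(dim=1).nonzero(as_tuple=True)[0]
    owner = nz // vocab_per_rank                 # destination rank for each row
    order = owner.argsort()                      # group rows by destination rank
    nz, owner = nz[order], owner[order]
    send_rows = grad[nz]

    # Exchange per-destination row counts so each rank knows what it will receive.
    send_counts = torch.bincount(owner, minlength=world_size)
    recv_counts = torch.empty_like(send_counts)
    dist.all_to_all_single(recv_counts, send_counts)
    in_splits, out_splits = send_counts.tolist(), recv_counts.tolist()
    n_recv = int(recv_counts.sum())

    # Exchange row indices, then the row values themselves.
    recv_idx = torch.empty(n_recv, dtype=nz.dtype, device=grad.device)
    dist.all_to_all_single(recv_idx, nz, output_split_sizes=out_splits, input_split_sizes=in_splits)
    recv_vals = torch.empty(n_recv, grad.size(1), dtype=grad.dtype, device=grad.device)
    dist.all_to_all_single(recv_vals, send_rows, output_split_sizes=out_splits, input_split_sizes=in_splits)

    # Sum duplicates into this rank's local shard of the embedding gradient.
    local = torch.zeros(vocab_per_rank, grad.size(1), dtype=grad.dtype, device=grad.device)
    local.index_add_(0, recv_idx - rank * vocab_per_rank, recv_vals)
    return local
```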
This works but slows down the run time by about a second and a half, since the implementation is sub-optimal. I suspect that even with a good implementation it might be worth it only for small batch sizes, as for large ones I'm expecting only about 65% sparsity. I'll throw an update here.
This was absolutely incredible @LarryTheLegend. 🔥
Wow, this is impressive!
I wonder if the "Smear token embeddings forward 1 position" (PR #130) can/should be removed now that the model is using bigram embeddings.
I just confirmed: removing the smear token embeddings and smear gate makes the training faster, because the Bigram Hash Embeddings are sufficient and powerful enough to overcome the effects of the smear gates. I will post the PR tomorrow.
I had tested this before I raised the PR, but found that 1 in 5 runs would randomly loss spike. But perhaps my tests were just atypically unlucky.
Indeed, you are right. It was just a lucky run on a machine with a single H100. I have now tested many times on 4xH100 and the training loss was not smaller when removing the smear embeddings. Just note that new scalars were added (Line 1183 in 7d502b9), so it should be: pad = -(num_layers * 4 + 3) % dist.get_world_size()
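To make the arithmetic concrete, here is a toy illustration of that padding formula (not the repo's code; the `num_layers` and `world_size` values are made up): the flat scalar count must be padded up to a multiple of the world size before it can be split evenly across ranks.

```python
# Toy illustration only; values are made up for the example.
num_layers, world_size = 12, 8
num_scalars = num_layers * 4 + 3             # updated count after the new bigram lambdas
pad = -num_scalars % world_size              # smallest pad making the total divisible
assert (num_scalars + pad) % world_size == 0
```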
Good catch on the padding, which is redundant now that all_reduce is being used on the scalars.

3 of the bigram lambdas go negative, which I had noticed before raising the PR and found quite peculiar, but I could not come to a conclusion on why this happens. Chris' charts are helpful for seeing this progression. This is a very different dynamic than is seen anywhere else in the model, and understanding what is driving it could inform a better connection pattern. To find the root cause we would probably want to remove backout, remove the skip connection, and remove the short/long window layer pattern, and see which aspect is driving it. Initially I thought it was driven by backout, where the skip from layer 3 to 7 was making a negative contribution to the residual stream lead to a positive contribution to the prediction... but from some light ablations this idea seemed less likely.

I have a separate hypothesis on why norm() fails on the bigram and value embeddings. The lambdas get applied indiscriminately to all tokens, so token magnitude is the main element that distinguishes whether a value or bigram embedding is more important for a given token. norm() prevents an embedding from going to zero. In particular, for bigram embeddings we want high-collision indices to go to zero (which is why they are initialized to zero, so the high-collision indices never disrupt the loss). Maybe there is a form of normalization that just caps high magnitudes while still enabling small ones to zero out (a rough sketch of one such cap is below).

bigram_lambda is jittery on layer 0 because it's added to a much smaller magnitude activation. Perhaps a simple /10 on layer zero would resolve it. Interesting to see some crazy spikes in the lambdas when the window size first updates.
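As an illustration of the "cap high magnitudes, still allow zeros" normalization being floated above (a hedged sketch, not something in the PR; `cap` is an arbitrary illustrative threshold): rescale only rows whose norm exceeds the cap, so small rows can still decay to zero.

```python
import torch

def cap_norm(emb: torch.Tensor, cap: float = 1.0, eps: float = 1e-6) -> torch.Tensor:
    norms = emb.norm(dim=-1, keepdim=True)
    scale = (cap / (norms + eps)).clamp(max=1.0)   # shrink only rows whose norm exceeds the cap
    return emb * scale
```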


Updates in PR:
Bigram hash code which runs on CPU during each dataloader iteration:
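The exact code is not reproduced here; as a rough sketch of the idea under stated assumptions (the prime multiplier and `num_buckets` are illustrative, not the PR's values), each (previous token, current token) pair is hashed on the CPU into a fixed number of buckets:

```python
import torch

def bigram_hash(tokens: torch.Tensor, num_buckets: int) -> torch.Tensor:
    """Hash each (previous token, current token) pair into [0, num_buckets)."""
    prev = torch.roll(tokens, shifts=1)
    prev[0] = 0                          # position 0 has no real predecessor (placeholder id)
    # Cheap multiplicative hash of the pair; 1_000_003 is an arbitrary prime, not the PR's choice.
    return (prev * 1_000_003 + tokens) % num_buckets
```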
Model structure:
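Again, only a hedged sketch of how such a table could be wired in, judging from the discussion above (zero-initialized table, learned bigram_lambda scale, direct addition to the residual stream); the PR's actual structure may differ:

```python
import torch
import torch.nn as nn

class BigramHashEmbedding(nn.Module):
    """Zero-initialized bigram-hash table added directly to the residual stream."""
    def __init__(self, num_buckets: int, dim: int):
        super().__init__()
        self.emb = nn.Embedding(num_buckets, dim)
        nn.init.zeros_(self.emb.weight)                      # zero init: collisions start out harmless
        self.bigram_lambda = nn.Parameter(torch.zeros(()))   # learned scale; initial value is a guess

    def forward(self, residual: torch.Tensor, bigram_ids: torch.Tensor) -> torch.Tensor:
        # bigram_ids come from the CPU-side hash computed in the dataloader.
        return residual + self.bigram_lambda * self.emb(bigram_ids)
```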
Recent discussion on DeepSeek's Engram got me looking into hash embeddings, and this 2017 paper: https://arxiv.org/abs/1709.03933. Because of the token distribution, meaningful hash collisions are rare, so hashing seems like a good fit.
I did some testing on more advanced ways to integrate the hashed embeddings with gating, trigrams, and multiple hashes. Very simple direct addition to the residual stream outperformed them by a decent margin.
Ideas
Interestingly, the model now has more parameters than training tokens. Typically when a feature gets added, the step count is reduced. At some point it may be better to drop an attn/mlp block.
Now that we have an additional 5*50304 params getting communicated across GPUs, and the bottleneck is overwhelmingly on comms, it could be worthwhile to build a sparse communication approach that only sends embeds with nonzero grads. I did not look at the profiler here or try other comms orderings. Due to stepping Adam every other step, early step times alternate between 31ms and 37ms. (There could be other ideas too, like stepping the bigram embed only once every 4 steps; see the toy sketch below.)
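As a toy sketch of that "step the bigram embed less often" idea (illustrative only; `main_opt`, `bigram_opt`, and the interval of 4 are placeholders, not the repo's code):

```python
import torch

def train_step(step: int, loss: torch.Tensor,
               main_opt: torch.optim.Optimizer,
               bigram_opt: torch.optim.Optimizer,
               bigram_interval: int = 4) -> None:
    """One step where the bigram-embedding params are stepped (and hence
    communicated) only every `bigram_interval` steps, accumulating grads in between."""
    loss.backward()
    main_opt.step()
    main_opt.zero_grad()
    if step % bigram_interval == 0:
        bigram_opt.step()
        bigram_opt.zero_grad()
```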
There is an average of 0.4s extra due to a stall on step 6. Chris pointed this out and I believe Varun has been working on a fix. I may have also gotten a worse machine here that was making the issue more pronounced.
Timing and Validation
Retiming of the prior record: 104.062s (mean of two runs: 104.129, 103.995)
If there are no further changes, this will be merged at 99.3s, to be consistent with a 5.6s improvement over 104.9s.