
New Record: Bigram Hash Embedding (-5.6s, -165 steps) #201

Merged
ClassicLarry merged 1 commit into KellerJordan:master from ClassicLarry:BigramHashEmbedding
Jan 20, 2026

Conversation

@ClassicLarry
Collaborator

Updates in PR:

  • Bigram Hash Embedding (-5.1s)
  • Fix partial key offset to apply to stationary dims instead of 50/50 stationary/rotating (-0.5s) [bug was introduced in paired head PR]
  • Reduce step count from 1765 to 1600
  • Increase cooldown frac from 0.5 to 0.55 (I find it works to increase this as step count decreases)

Bigram hash code which runs on CPU during each dataloader iteration:

import torch

# vocab_size and args come from the surrounding training script
args.bigram_vocab_size = 5 * vocab_size

def get_bigram_hash(x):
    """
    Computes bigram hash for each position using [prev_token, curr_token].
    Multiply by arbitrary large ints to get even spread over int32 range.
    Position 0 is mapped to the reserved index (bigram_vocab_size - 1).
    BOS tokens within the batch will hash based on the last token of the prior doc. Masking this ran slower and showed no improvement.
    """
    rand_int_1 = 36313
    rand_int_2 = 27191
    mod = args.bigram_vocab_size - 1
    x = x.to(torch.int32).clone()
    x[0] = mod
    x[1:] = torch.bitwise_xor(rand_int_1 * x[1:], rand_int_2 * x[:-1]) % mod
    return x
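A minimal usage sketch (mine, not from the PR; vocab_size is whatever the training script uses):

tokens = torch.randint(0, vocab_size, (1024,), dtype=torch.int32)   # toy 1-D token sequence
bigram_ids = get_bigram_hash(tokens)   # same shape as tokens, every value < args.bigram_vocab_size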

Model structure:

# model args
embed = nn.Embedding(vocab_size, model_dim)
bigram_embed = nn.Embedding(bigram_vocab_size, model_dim).zero_init()
x0_lambdas = nn.Parameter(torch.zeros(num_layers))
bigram_lambdas = nn.Parameter(0.1 * torch.ones(num_layers))

# model forward pass
x = x0 = norm(embed(input))
x0_bigram = bigram_embed(get_bigram_hash(input))
for i in range(num_layers):
    x = x + x0_lambdas[i] * x0 + bigram_lambdas[i] * x0_bigram
    x = block(x)
logits = lm_head(x)

Recent discussion on Deepseek's Engram got me looking into hash embeddings, and this 2017 paper: https://arxiv.org/abs/1709.03933. Meaningful hash collisions are rare due to token distribution, so hashing seems good.
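As a rough sanity check on that claim (my own sketch, not from the PR), one can hash every adjacent bigram in a token sample with the same mixing constants as get_bigram_hash and count buckets that receive more than one distinct bigram; tokens is assumed to be a 1-D int tensor of training tokens:

from collections import defaultdict
import torch

def bigram_collision_stats(tokens, bigram_vocab_size):
    mod = bigram_vocab_size - 1
    prev = tokens[:-1].to(torch.int32)
    curr = tokens[1:].to(torch.int32)
    hashes = torch.bitwise_xor(36313 * curr, 27191 * prev) % mod
    buckets = defaultdict(set)
    for h, p, c in zip(hashes.tolist(), prev.tolist(), curr.tolist()):
        buckets[h].add((p, c))
    collided = sum(1 for s in buckets.values() if len(s) > 1)
    return collided, len(buckets)   # buckets holding >1 distinct bigram, total buckets hit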

I did some testing on more advanced ways to integrate the hashed embeddings, using gating, trigrams, and multiple hashes. Very simple direct addition to the residual stream outperformed them by a decent margin.

Ideas

Interestingly, the model now has more parameters than training tokens. Typically when a feature gets added the step count is reduced. At some point it may be better to drop an attn/mlp.

Now that we have an additional 5*50304 embedding rows whose gradients get communicated across GPUs, and the bottleneck is overwhelmingly on comms, it could be worthwhile to build a sparse communication approach that only sends the embeddings with nonzero grads. I did not look at the profiler here or try other comms orderings. Because Adam only steps every other step, early step times alternate between 31ms and 37ms. (There could be other ideas too, like stepping the bigram embed only once every 4 steps.)

There is an average of 0.4s extra due to a stall on step 6. Chris pointed this out and I believe Varun has been working on a fix. I may have also gotten a worse machine here that was making the issue more pronounced.

Timing and Validation

import scipy.stats
import torch

losses = [3.2756, 3.2786, 3.2773, 3.2791, 3.2778, 3.277]
times = [98.065, 98.017, 99.201, 99.151, 98.112, 98.031]

print("p=%.4f" % scipy.stats.ttest_1samp(losses, 3.28, alternative="less").pvalue)
# p=0.0025

print("losses:", torch.std_mean(torch.tensor(losses)))
# losses: (tensor(0.0012), tensor(3.2776))

print("time:", torch.std_mean(torch.tensor(times)))
# time: (tensor(0.5794), tensor(98.4295))

retiming prior record: 104.062 [104.129, 103.995]

If there are no further changes, I will record the merge at 99.3s, consistent with a 5.6s improvement over the prior 104.9s.

@alint77

alint77 commented Jan 20, 2026

did you run ablation test on bigram vs Fix partial key offset?

@ClassicLarry
Collaborator Author

did you run ablation test on bigram vs Fix partial key offset?

Ya, fixing the partial key offset dims was worth roughly 0.001 loss, or half a second.

@ClassicLarry
Collaborator Author

Did some validation because the amount of recent loss decrease has made me more cautious about data leakage.
HellaSwag, 1000 questions: 22% at init, 32.5% for the trained model.
If I replace the input data with random ints, loss sits around 10.8 throughout training. (With data leakage, the model could still find a mechanism to drive loss down.)
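A tiny sketch of that random-int check (assuming inputs is the batch the dataloader would normally yield):

import math, torch
inputs = torch.randint(0, vocab_size, inputs.shape, device=inputs.device)  # replace real tokens
# with random inputs, cross-entropy should stay near the uniform baseline:
# math.log(vocab_size) ≈ math.log(50304) ≈ 10.83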

@ClassicLarry ClassicLarry merged commit 93f0e6b into KellerJordan:master Jan 20, 2026
@trianxy
Contributor

trianxy commented Jan 20, 2026

Wow! 5% improvement at this point is big. Congrats.

Having Bigram Hash Embedding yield such an improvement is exciting. Also, it's interesting that trigrams or multiple hashes didn't show additional improvement. Sounds to me like there's some interesting stuff to understand there.

@ClassicLarry
Collaborator Author

Wow! 5% improvement at this point is big. Congrats.

Having Bigram Hash Embedding yield such an improvement is exciting. Also, it's interesting that trigrams or multiple hashes didn't show additional improvement. Sounds to me like there's some interesting stuff to understand there.

I want to clarify that only my particular implementations of trigrams and multiple hashes didn't help. In an execution-driven field, negative results are close to meaningless. There are likely extensions that will be beneficial with the right implementation nuances.

@shenberg
Contributor

shenberg commented Jan 21, 2026

I wrote an initial sparse-comms implementation for the gradient reduce-scatter (param updates are dense at the moment). Works using the all_to_all_single strategy (I believe it's the horovod embedding all-to-all strategy):

  1. Each rank sends to the others the count of (unique) tokens it saw whose gradient rows they own.
  2. Send the actual token ids to every rank, async, overlapping with the forward() call.
  3. In the optimizer reduce-scatter step, send the actual gradient rows for the token ids we saw to the appropriate rank.
  4. Before using the gradients, each rank waits to receive the gradient rows and indices from all the other ranks and reconstructs its gradient shard by adding everything up.

This works but slows down the run by about a second and a half, since the implementation is sub-optimal. I suspect that even with a good implementation it might only be worth it for small batch sizes, as for large ones I'm expecting about 65% sparsity. I'll throw some torch.compile() at it and profile once there are spot instances available.

Update: torch.compile gets it to be about 200ms slower than dense baseline. I'll look into it a bit more.
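For anyone curious, here is a rough single-function sketch of that exchange pattern as I understand it (my own illustration, not shenberg's code; the contiguous row-sharded layout and the name sparse_embed_grad_exchange are assumptions):

import torch
import torch.distributed as dist

def sparse_embed_grad_exchange(grad, vocab_size):
    # grad: this rank's full (vocab_size, dim) embedding gradient.
    # Returns this rank's reduced shard of rows [lo, hi).
    world, rank = dist.get_world_size(), dist.get_rank()
    shard = (vocab_size + world - 1) // world
    lo, hi = rank * shard, min((rank + 1) * shard, vocab_size)

    # 1) rows this rank touched (nonzero gradient), and which rank owns each
    touched = grad.abs().sum(dim=1).nonzero(as_tuple=True)[0]
    owners = touched // shard
    order = owners.argsort()                       # group rows by destination rank
    touched = touched[order]
    send_counts = torch.bincount(owners, minlength=world)

    # 2) exchange per-rank counts, then the row indices themselves
    recv_counts = torch.empty_like(send_counts)
    dist.all_to_all_single(recv_counts, send_counts)
    out_splits, in_splits = recv_counts.tolist(), send_counts.tolist()
    recv_idx = torch.empty(sum(out_splits), dtype=touched.dtype, device=grad.device)
    dist.all_to_all_single(recv_idx, touched, out_splits, in_splits)

    # 3) exchange the corresponding gradient rows
    recv_rows = torch.empty(sum(out_splits), grad.shape[1], dtype=grad.dtype, device=grad.device)
    dist.all_to_all_single(recv_rows, grad[touched], out_splits, in_splits)

    # 4) reconstruct this rank's shard by summing contributions from all ranks
    my_shard = torch.zeros(hi - lo, grad.shape[1], dtype=grad.dtype, device=grad.device)
    my_shard.index_add_(0, recv_idx - lo, recv_rows)
    return my_shard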

@chrisjmccormick
Contributor

This was absolutely incredible @ LarryTheLegend. 🔥
The big comms cost just creates a fun opportunity for optimization--I'm looking forward to experimenting.

@kroggen

kroggen commented Jan 26, 2026

Wow, this is impressive!
I was expecting it to take more time to learn with more parameters, but instead it learns faster!
Very cool and interesting

@kroggen

kroggen commented Jan 26, 2026

I wonder if the "Smear token embeddings forward 1 position" (PR #130) can/should be removed now that the model is using bigram embedding

multiple attention heads were consistently attending to the prior token

@kroggen

kroggen commented Jan 27, 2026

I just confirmed: removing the smear token embeddings and smear gate makes the training faster, all because the Bigram Hash Embeddings are sufficient and powerful enough to overcome the effects of smear gates

I will post the PR tomorrow

@ClassicLarry
Collaborator Author

I just confirmed: removing the smear token embeddings and smear gate makes the training faster, all because the Bigram Hash Embeddings are sufficient and powerful enough to overcome the effects of smear gates

I will post the PR tomorrow

I had tested this before I raised the PR but found that 1 in 5 runs would randomly loss-spike. But perhaps my tests were just atypically unlucky.

@chrisjmccormick
Contributor

May or may not be related, but layer 0 bigram lambda is noisy:

[image: bigram lambdas over training steps]

I've implemented a feature for the optimizer which will allow us to separate out all of the scalars--the optimizer handles packing and unpacking them for us, instead of doing that inside the model. That lets us specify betas and lrs for everything, and even break off individual parameters like "bigram_lambda_l0" if we want.

@chrisjmccormick
Contributor

I've used it to finally get the resid lambdas under control. I don't know if I'll be able to get a record out of any of the tuning, so my plan was to do more of an "added feature" PR for it.

[image: residual stream lambdas over training steps]

@kroggen

kroggen commented Jan 27, 2026

I had tested this before I raised the PR but found that 1 in 5 runs would randomly loss-spike. But perhaps my tests were just atypically unlucky.

Indeed, you are right. It was just a lucky run on a machine with a single H100. Tested now many times on 4xH100 and the training loss was not smaller when removing the smear embeddings.

Just note that new scalars were added to self.scalars on this PR but the padding was not updated.
This line remains the same:

pad = (-num_layers * 3 - 3) % dist.get_world_size() # updated: 3*num_layers instead of 4*

It should be:

pad = -(num_layers * 4 + 3) % dist.get_world_size()

@ClassicLarry
Collaborator Author

Good catch on the padding, which is redundant now that all_reduce is being used on the scalars.

3 of the bigram lambdas go negative, which I had noticed before raising the PR and found quite peculiar but could not come to a conclusion on why this happens. Chris' charts are helpful for seeing this progression. This is a very different dynamic than is seen anywhere else in the model, and understanding what is driving it could inform a better connection pattern. To find the root cause we would probably want to remove backout, remove the skip connection, remove the short/long window layer pattern and see which aspect is driving it. Initially I thought it was driven by backout where the skip from layer 3 to 7 was making it so that a negative contribution to the residual stream was leading to a positive contribution to prediction... but from some light ablations this idea seemed less likely.

I have a separate hypothesis on why norm() fails on the bigram and value embeddings. The lambdas get applied indiscriminately to all tokens, so token magnitude is the main thing that distinguishes whether a value or bigram embedding matters more for a given token. norm() prevents an embedding from going to zero. In particular, for bigram embeddings we want high-collision indices to go to zero (which is why they are initialized to zero, so the high-collision indices never disrupt the loss). Maybe there is a form of normalization that just caps high magnitudes while still letting small ones zero out.
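One concrete form of that idea (purely a hypothetical sketch, not something tested in this PR) would be a cap-only normalization that rescales rows above a maximum norm but leaves small rows free to shrink toward zero:

import torch

def cap_norm(x, max_norm=1.0, eps=1e-6):
    # shrink rows whose norm exceeds max_norm; never scale small rows up,
    # so high-collision bigram rows can stay near zero
    scale = (max_norm / (x.norm(dim=-1, keepdim=True) + eps)).clamp(max=1.0)
    return x * scale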

bigram_lambda is jittery on layer 0 because it's added to a much smaller-magnitude activation. Perhaps a simple /10 on layer zero would resolve it. Interesting to see some crazy spikes in the lambdas when the window size first updates.
