[New Record] Removing H2D Movement in Bigram Hash Embedding (-0.25s)#216
Merged
ClassicLarry merged 5 commits into KellerJordan:master on Feb 11, 2026
Conversation
Contributor
Author
Fixed merge conflicts, retimed, and updated the logs; still seeing the speedup (-0.25s):

```python
import scipy.stats
import torch

# First set (Retimed New Record group)
losses1 = [3.2769, 3.2807, 3.2797, 3.2765, 3.2777, 3.2771, 3.2787, 3.2796]
times1 = [95818, 95753, 95739, 95782, 95723, 95723, 95698, 95768]

# Second set (Retimed Current Record group)
losses2 = [3.2775, 3.2794, 3.2796, 3.2812, 3.2768, 3.2763]
times2 = [96016, 96002, 96001, 95940, 96128, 96113]

print("p=%.4f" % scipy.stats.ttest_1samp(losses1, 3.28, alternative="less").pvalue)
# p=0.0098
print("losses:", torch.std_mean(torch.tensor(losses1)))
# losses: (tensor(0.0015), tensor(3.2784))
print("time:", torch.std_mean(torch.tensor(times1, dtype=torch.float32)))
# time: (tensor(38.2884), tensor(95750.5000))
print("losses:", torch.std_mean(torch.tensor(losses2)))
# losses: (tensor(0.0019), tensor(3.2785))
print("time:", torch.std_mean(torch.tensor(times2, dtype=torch.float32)))
# time: (tensor(72.5911), tensor(96033.3359))
```
Contributor
For fun I ran the hash function through Triton max autotuning and got: [autotune results omitted]. I believe XBLOCK is chosen to be 1024.
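For reference, a minimal sketch of how such an autotuning run could be set up. The `bigram_hash` body below is a stand-in rather than the repo's actual hash function, and `mode="max-autotune"` is what asks Inductor to autotune the Triton kernel it generates:

```python
import torch

@torch.compile(mode="max-autotune")  # Inductor autotunes the Triton kernel it generates
def bigram_hash(tokens: torch.Tensor, num_buckets: int = 50304) -> torch.Tensor:
    # Stand-in bigram hash: combine each token with its predecessor, then bucket.
    prev = torch.roll(tokens, shifts=1, dims=-1)
    return (prev * 1000003 + tokens) % num_buckets

ids = torch.randint(0, 50304, (8, 1024), device="cuda")
buckets = bigram_hash(ids)  # first call triggers compilation and autotuning
```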
Collaborator
Will merge at -0.3 later this week. Nice find!
This simple optimization moves the bigram hash computation to the GPU, which eliminates one H2D transfer per batch. It also simplifies the data loader and leaves one fewer tensor to manage in the async pipeline.
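A minimal sketch of the idea (the function names and hash formula here are illustrative, not the repo's actual code): the token batch is already being copied to the GPU, so the bigram hash can be derived on-device instead of being computed on the CPU and shipped across as a second tensor.

```python
import torch

def load_batch(tokens_cpu: torch.Tensor) -> torch.Tensor:
    # Only the raw token ids cross the bus (one async H2D copy per batch).
    return tokens_cpu.to("cuda", non_blocking=True)

def bigram_hash_gpu(tokens: torch.Tensor, num_buckets: int) -> torch.Tensor:
    # Hypothetical bigram hash, computed from tokens already resident on the GPU,
    # replacing a CPU-side hash plus a separate H2D transfer in the data loader.
    prev = torch.roll(tokens, shifts=1, dims=-1)
    return (prev * 1000003 + tokens) % num_buckets
```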
Timing and Validation
On my hardware (tested in Docker), this PR improves wall-clock time without impacting loss.
Discussion
I did want to clarify whether my torch.compile usage is completely okay (the flags in the decorator also appear in the repo's other decorators). This tiny patch was largely inspired by my first proper pass through the nanogpt repo, which I found through the Bigram Hash Embedding tweet. Excited to contribute going forward!
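For concreteness, the kind of decorator being asked about might look like the sketch below; the specific flags (`fullgraph=True`, `dynamic=False`) are placeholders standing in for whatever the repo's other decorators pass, not a claim about this PR's actual code.

```python
import torch

# Hypothetical flags; the real values are whatever the surrounding
# @torch.compile decorators in the repo already use.
@torch.compile(fullgraph=True, dynamic=False)
def prepare_inputs(tokens: torch.Tensor) -> torch.Tensor:
    return tokens.roll(shifts=1, dims=-1)
```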