New WR: sparse bigram gradient comms (-0.6 seconds) #221
Merged
ClassicLarry merged 3 commits into KellerJordan:master on Feb 16, 2026
Conversation
Contributor
FYI, in case it's helpful for catching things up: I integrated the PRs prior to yours into your code and re-ran it for four runs.
Collaborator
Merging at 91.0 (-0.3s) based on re-timing. (This one undoes the 0.3s gain from moving the bigram to GPU.) I will do a merge afterwards to clean up the merge conflicts.
The title timing is a mistake; it's closer to -0.75 seconds.
This is an update & cleanup of #219. No ML changes.
This PR contains three changes:
The sparse comms implementation saves a lot of bandwidth: even at the largest batch size, the amount of communication (for the scatter) is about the same as for the embedding and lm_head layers. The cost is extra compute to reconstitute the rank-local gradient, which is hidden by overlapping it with other communication.
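As a rough illustration of the idea (not this PR's actual implementation, which uses a scatter/gather schedule), the bigram embedding gradient is nonzero only for the rows a rank's batch touched, so each rank can exchange just those (index, value) rows and rebuild the dense gradient locally. The function name and the padded all_gather below are illustrative assumptions; it assumes an initialized process group and CUDA tensors.

```python
import torch
import torch.distributed as dist

def exchange_sparse_grad(dense_grad: torch.Tensor, local_rows: torch.Tensor) -> None:
    """Sum a row-sparse gradient across ranks while only communicating touched rows.

    dense_grad: [vocab, dim] rank-local gradient (nonzero only at local_rows).
    local_rows: 1D int64 tensor of row indices that are nonzero on this rank.
    """
    world = dist.get_world_size()
    device = dense_grad.device

    # Each rank may touch a different number of rows, so share the counts first
    # and pad everything to the maximum (simple, not bandwidth-optimal).
    count = torch.tensor([local_rows.numel()], device=device)
    counts = [torch.empty_like(count) for _ in range(world)]
    dist.all_gather(counts, count)
    max_n = int(torch.cat(counts).max())

    idx_buf = torch.full((max_n,), -1, dtype=torch.int64, device=device)
    idx_buf[: local_rows.numel()] = local_rows
    val_buf = torch.zeros(max_n, dense_grad.size(1), dtype=dense_grad.dtype, device=device)
    val_buf[: local_rows.numel()] = dense_grad[local_rows]

    all_idx = [torch.empty_like(idx_buf) for _ in range(world)]
    all_val = [torch.empty_like(val_buf) for _ in range(world)]
    dist.all_gather(all_idx, idx_buf)
    dist.all_gather(all_val, val_buf)

    # Reconstitute the full dense gradient locally from everyone's sparse rows.
    # This is the extra compute that gets hidden behind other communication.
    dense_grad.zero_()
    for idx, val in zip(all_idx, all_val):
        keep = idx >= 0
        dense_grad.index_add_(0, idx[keep], val[keep])
```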
Moving the bigram index calculation directly into a pinned tensor saved a lot of time in the forward pass, as `.to(device, non_blocking=True)` was very slow. Since I need the index on the CPU as part of the sparse communication scheme, this is unfortunately mutually exclusive with #216, though I think much of the gain in #216 is already folded into moving to a pinned tensor.
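For the pinned-tensor point, here is a minimal sketch of the pattern, assuming a hypothetical `bigram_index_to_gpu` helper; the indexing formula and buffer size are illustrative, not the PR's code. The relevant detail is that `.to(device, non_blocking=True)` can only be truly asynchronous when the source CPU tensor is pinned; from pageable memory it tends to block and run much slower.

```python
import torch

VOCAB_SIZE = 50304  # illustrative vocab size
# Reusable page-locked host buffer; writing the index straight into it avoids
# an extra pageable-to-pinned staging copy on every step.
pinned_idx = torch.empty(64 * 1024, dtype=torch.int64, pin_memory=True)

def bigram_index_to_gpu(tokens_cpu: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Hypothetical helper: build a bigram id per position and copy it to the GPU."""
    n = tokens_cpu.numel()
    out = pinned_idx[:n]
    torch.mul(tokens_cpu, VOCAB_SIZE, out=out)  # write directly into pinned memory
    out[1:] += tokens_cpu[:-1]                  # combine each token with its predecessor
    # The CPU copy (out) is also what the sparse communication scheme needs later,
    # so computing the index on the CPU serves double duty.
    return out.to(device, non_blocking=True)    # async H2D copy from pinned memory
```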
Ablation: a build with only changes 2 & 3 was about 500ms faster than baseline (with `_sparse_comms_active()` hard-coded to `return False`).

The scatter-gather order is a bit under-explored: using a profiler, it's easy to see that the overlap is not perfect despite there being enough transfers queued. I found a configuration that improved timing by another 100ms (and there should be about another 200ms on top of that), but for some reason loss increased.
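A sketch of the scheduling aspect (the issue order and names here are assumptions, not the PR's actual schedule): collectives launched with `async_op=True` return handles that can be awaited much later, so the order in which the dense reductions and the sparse index/value transfers are queued determines how well they overlap with each other and with compute.

```python
import torch
import torch.distributed as dist

def overlapped_grad_comms(dense_grads, sparse_idx, sparse_val):
    """Queue all transfers up front and wait as late as possible."""
    world = dist.get_world_size()
    handles = []

    # Dense gradients: ordinary all-reduce, launched but not awaited yet.
    for g in dense_grads:
        handles.append(dist.all_reduce(g, async_op=True))

    # Sparse bigram gradient: gather indices and values from every rank.
    idx_out = [torch.empty_like(sparse_idx) for _ in range(world)]
    val_out = [torch.empty_like(sparse_val) for _ in range(world)]
    handles.append(dist.all_gather(idx_out, sparse_idx, async_op=True))
    handles.append(dist.all_gather(val_out, sparse_val, async_op=True))

    # ...unrelated compute can run here while the transfers are in flight...

    for h in handles:
        h.wait()
    return idx_out, val_out
```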
Will upload the logs in a bit. See #219 for more details of the sparse communication algorithm until I move them into the readme here.