
[New Record] Removing H2D Movement in Bigram Hash Embedding (-0.25s)#216

Merged
ClassicLarry merged 5 commits into KellerJordan:master from dhruvji:dhruv/bigram-hash-gpu
Feb 11, 2026

Conversation

@dhruvji
Contributor

@dhruvji dhruvji commented Jan 31, 2026

This simple optimization moves the bigram hash computation onto the GPU, eliminating one H2D transfer per batch.

It also simplifies the data loader, leaving one fewer tensor to manage in the async pipeline.
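A minimal sketch of the idea (function name and constants are illustrative, not the PR's actual code): derive the bigram hash from the token tensor wherever it already lives, so only the raw tokens cross the PCIe bus.

```python
import torch

def bigram_hash(tokens: torch.Tensor, num_buckets: int = 50304) -> torch.Tensor:
    # Hash each (previous, current) token pair into an embedding bucket.
    # Runs on whatever device `tokens` is on, so if the tokens are already
    # on the GPU, no extra host-to-device copy is needed.
    prev = torch.roll(tokens, 1)
    prev[0] = 0  # the first position has no predecessor
    h = (tokens * 36313) ^ (prev * 27191)  # multiplier constants are illustrative
    return h % num_buckets

# Before: the data loader hashed on CPU and shipped two tensors to the GPU.
# After: only `tokens` is transferred; the hash is computed on-device.
tokens = torch.randint(0, 50257, (32768,))  # on GPU in practice: device="cuda"
buckets = bigram_hash(tokens)
```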

Timing and Validation

This PR improves the wall-clock time and does not change the loss on my hardware (tested in Docker).

import scipy.stats
import torch

# First set (New Record group)
losses1 = [3.2771, 3.2781, 3.2787, 3.2786, 3.2761, 3.2777, 3.2783]
times1 = [97887, 97942, 98031, 97906, 97840, 97901, 97843] # in ms

# Second set (Prev Record group)
losses2 = [3.2778, 3.2780, 3.2777, 3.2818, 3.2769, 3.2780]
times2 = [98036, 98178, 98158, 98135, 98137, 98146] # in ms

print("=== New Record Set ===")
print("p=%.4f" % scipy.stats.ttest_1samp(losses1, 3.28, alternative="less").pvalue)
# p=0.0004
print("losses:", torch.std_mean(torch.tensor(losses1)))
# losses: (tensor(0.0009), tensor(3.2778))
print("time:", torch.std_mean(torch.tensor(times1, dtype=torch.float32)))
# time: (tensor(65.3489), tensor(97907.1406))

print("\n=== Prev Record Set ===")
print("losses:", torch.std_mean(torch.tensor(losses2)))
# losses: (tensor(0.0017), tensor(3.2784))
print("time:", torch.std_mean(torch.tensor(times2, dtype=torch.float32)))
# time: (tensor(49.4719), tensor(98131.6641))
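The script above only t-tests the losses against the 3.28 target; as a quick extra sanity check on the timing claim (a sketch, not part of the PR), a two-sample Welch test on the same `times1`/`times2` lists also comes out strongly significant:

```python
import scipy.stats

# Wall-clock times (ms) copied from the two runs above.
times1 = [97887, 97942, 98031, 97906, 97840, 97901, 97843]  # new record
times2 = [98036, 98178, 98158, 98135, 98137, 98146]         # prev record

# Welch's t-test (unequal variances): is the new set faster than the old?
res = scipy.stats.ttest_ind(times1, times2, equal_var=False, alternative="less")
print("p=%.2g" % res.pvalue)  # p is far below 0.01
```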

Discussion

I did want to check whether my torch.compile usage is acceptable (the flags in the decorator match those used on other decorators in the repo). This tiny patch grew out of my first proper pass over the nanogpt repo, which I found through the Bigram Hash Embedding tweet. Excited to contribute going forward!

@dhruvji dhruvji changed the title from "[New Record] Removing H2D Movement in Bigram Hash Embedding (-0.2s)" to "[New Record] Removing H2D Movement in Bigram Hash Embedding (-0.25s)" on Feb 1, 2026
@dhruvji
Contributor Author

dhruvji commented Feb 1, 2026

Fixed merge conflicts, retimed, and updated the logs; still seeing the speedup (-0.25s).

import scipy.stats
import torch

# First set (Retimed New Record group)
losses1 = [3.2769, 3.2807, 3.2797, 3.2765, 3.2777, 3.2771, 3.2787, 3.2796]
times1 = [95818, 95753, 95739, 95782, 95723, 95723, 95698, 95768]

# Second set (Retimed Current Record group)
losses2 = [3.2775, 3.2794, 3.2796, 3.2812, 3.2768, 3.2763]
times2 = [96016, 96002, 96001, 95940, 96128, 96113]

print("p=%.4f" % scipy.stats.ttest_1samp(losses1, 3.28, alternative="less").pvalue)
# p=0.0098
print("losses:", torch.std_mean(torch.tensor(losses1)))
# losses: (tensor(0.0015), tensor(3.2784))
print("time:", torch.std_mean(torch.tensor(times1, dtype=torch.float32)))
# time: (tensor(38.2884), tensor(95750.5000))

print("losses:", torch.std_mean(torch.tensor(losses2)))
# losses: (tensor(0.0019), tensor(3.2785))
print("time:", torch.std_mean(torch.tensor(times2, dtype=torch.float32)))
# time: (tensor(72.5911), tensor(96033.3359))

@varunneal
Contributor

For fun, I ran the hash function through Triton max-autotuning and got:

import triton
import triton.language as tl
# (`libdevice` is bound by Inductor's generated preamble, omitted here.)

@triton.jit
def triton_poi_fused_arange_bitwise_xor_eq_mul_new_full_new_zeros_remainder_roll_where_0(
    in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr
):
    xnumel = 32768
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    x0 = xindex

    tmp3 = tl.load(in_ptr0 + (x0), None)
    tmp6 = tl.load(in_ptr0 + (((32767 + x0) % 32768)), None, eviction_policy='evict_last')

    tmp2 = x0 == tl.full([1], 0, tl.int64)
    tmp8 = tl.where(tmp2, tl.full([1], 0, tl.int32), tmp6)

    tmp5  = tmp3 * tl.full([1], 36313, tl.int32)
    tmp10 = tmp8 * tl.full([1], 27191, tl.int32)
    tmp11 = tmp5 ^ tmp10

    tmp12 = tl.full([1], 50303, tl.int32)
    tmp13 = (tmp11 % tmp12)

    tmp14 = tmp13 != tl.full([1], 0, tl.int32)
    tmp15 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0
    tmp16 = (libdevice.signbit(tmp12) != 0) if (tmp12).dtype is tl.float32 else tmp12 < 0
    tmp18 = tmp14 & (tmp15 != tmp16)

    tmp20 = tl.where(tmp18, tmp13 + tmp12, tmp13)
    tmp21 = tl.where(tmp2, tmp12, tmp20)

    tl.store(out_ptr0 + (x0), tmp21, None)

and I believe XBLOCK is chosen to be 1024.
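For readers who don't want to parse the generated kernel: an eager-mode PyTorch equivalent of what the fused kernel computes (my reading of it, so treat it as a sketch) is:

```python
import torch

def bigram_hash_eager(tokens: torch.Tensor) -> torch.Tensor:
    # Eager-mode rendering of the fused Triton kernel above:
    #   prev = tokens rolled right by one, with position 0 treated as 0
    #   hash = ((tok * 36313) ^ (prev * 27191)) mod 50303
    #   position 0 is pinned to the sentinel bucket 50303
    idx = torch.arange(tokens.numel(), device=tokens.device)
    first = idx == 0
    prev = torch.where(first, torch.zeros_like(tokens), torch.roll(tokens, 1))
    h = ((tokens * 36313) ^ (prev * 27191)) % 50303
    # torch's % already returns a non-negative result for a positive divisor,
    # which is what the kernel's sign-fixup branch (tmp14..tmp20) implements.
    return torch.where(first, torch.full_like(h, 50303), h)
```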

@ClassicLarry
Collaborator

Will merge at -0.3 later this week. Nice find!

@ClassicLarry ClassicLarry merged commit 3eb3a4d into KellerJordan:master Feb 11, 2026