Commit bc39fbc

fix(gdn_decode): widen pool indices to Int64 to prevent int32 element-offset overflow (#3230)
## 📌 Description

Fix `CUDA error: an illegal memory access was encountered` in `flashinfer.gdn_decode.gated_delta_rule_decode_pretranspose` when the pool+indices API is used with sufficiently large pool indices.

**Root cause.** The CuTe-DSL kernels compute the per-slot element offset (`pool_idx * stride[0]`, or `(cache_idx * HV + i_hv) * stride[0]` for bf16) using **Int32** arithmetic. Once the product exceeds `INT32_MAX`, it wraps to a negative offset and the load/store hits an unmapped global address.

Affects both backends the API can dispatch to (HV=32, V=K=128):

| backend | kernel | overflow threshold |
|---|---|---|
| fp32 pretranspose | `gdn_decode_kernel_{small,big}_batch_pretranspose` | `pool_idx >= 3972` (vLLM padded slot stride 540672) |
| bf16 fast path | `gdn_decode_bf16state_mtp_kernel` | `cache_idx >= 4096` (contiguous, `stride[0] = HV*V*K = 524288`) |

Discovered while integrating the kernel into vLLM's GDN decode path for **Qwen3.5-class models**.

**Fix.** Widen the pool indices to Int64 immediately after they are read; downstream offsets in `cute.local_tile(...)` / `h0_source[(...)]` then promote to Int64 and cannot wrap:

```python
# fp32 pretranspose (small + big batch)
pool_idx = cutlass.Int64(h0_indices[i_n])
out_pool_idx = cutlass.Int64(h0_out_indices[i_n])

# bf16 MTP — propagates Int64 through flat_state_idx,
# flat_write_idx, and the intermediate-states cache's flat_idx.
cache_idx = cutlass.Int64(h0_indices[i_n])
write_cache_idx = cutlass.Int64(h0_out_indices[i_n])
```

## 🔍 Related Issues

Same class of bug as [#3005](#3005) / [#3007](#3007) (rmsnorm stride overflow), in a different family of CuTe-DSL kernels.

## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks

- [x] `pre-commit` installed and hooks installed.
- [x] `pre-commit run --files <changed files>` — all hooks pass.

## 🧪 Tests

- [x] Tests added.
- [x] All tests pass.

Added `tests/gdn/test_decode_pretranspose_noncontiguous_pool.py`:

- `test_decode_pretranspose_pool_int64_offset[3972, 8191]` — fp32 vLLM-padded pool (~8.6 / 17.7 GB).
- `test_decode_pretranspose_pool_int64_offset_bf16[4096, 4196]` — bf16 contiguous pool (~4.3 GB).

Both compare the pool path against a gather + direct-state reference (numerical correctness, not just non-crashing) and assert the in-place state update matches. VRAM-based skip when free memory is insufficient.

**Verified on NVIDIA B200 (SM100):** all 4 new tests crash without the fix and pass with it; existing `pretranspose` and `bf16_state` tests in `tests/gdn/test_decode_delta_rule.py` continue to pass.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

* **Bug Fixes**
  * Resolved integer overflow in GPU decode kernels by switching pool- and state-index arithmetic to 64-bit, preventing wraparound and out-of-bounds addressing for large pools and batches.
  * Ensured consistent 64-bit handling across all decode paths and negative-index clamping.
* **Tests**
  * Added GPU regression tests covering large-pool overflow scenarios for FP32 and BF16 fast paths, with device-capacity guards to avoid OOM on low-VRAM systems.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
1 parent 4381afc commit bc39fbc
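
As a quick, standalone sanity check of the overflow thresholds quoted in the description (not code from this change), the wraparound can be reproduced with plain Python arithmetic; `wrap_int32` is an illustrative helper that emulates two's-complement Int32 multiplication:

```python
# Emulate two's-complement int32 wraparound of an unbounded Python integer.
def wrap_int32(x: int) -> int:
    x &= 0xFFFFFFFF
    return x - (1 << 32) if x >= (1 << 31) else x

INT32_MAX = 2**31 - 1

# fp32 pretranspose path: vLLM padded slot stride of 540672 elements.
vllm_stride = 540672
pool_idx = 3972                        # first index past the threshold
offset = pool_idx * vllm_stride        # 2_147_549_184 > INT32_MAX
print(offset, wrap_int32(offset))      # 2147549184 -> -2147418112 (negative offset)

# bf16 fast path: contiguous pool, stride[0] = HV * V * K with HV=32, V=K=128.
HV, V, K = 32, 128, 128
cache_idx, i_hv = 4096, 0
offset = (cache_idx * HV + i_hv) * V * K   # exactly 2**31
print(offset, wrap_int32(offset))          # 2147483648 -> -2147483648

# Widened to 64-bit, the same products stay far below 2**63 and cannot wrap.
```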

3 files changed: 403 additions & 16 deletions

flashinfer/gdn_kernels/gdn_decode_bf16_state.py (54 additions & 10 deletions)
```diff
@@ -163,15 +163,28 @@ def gdn_decode_bf16state_mtp_ilp4_kernel(
     i_n = tmp // HV
     i_h = i_hv // (HV // H)

-    cache_idx = h0_indices[i_n]
+    # Widen pool indices to Int64 BEFORE they multiply ``stride[0]`` of
+    # ``h0_source``. The reshape ``[pool_size, HV, V, K] -> [pool_size * HV,
+    # V, K]`` (BF16) gives ``stride[0] = V * K = 16384`` elements, so the
+    # downstream offset ``(cache_idx * HV + i_hv) * 16384`` reaches 2**31 at
+    # ``cache_idx >= ceil(2**31 / (HV * V * K)) = 4096`` (HV=32, V=K=128;
+    # 2048 at HV=64). Past that boundary the Int32 multiplication wraps to
+    # a negative offset and ``cute.local_tile(h0_source, ...)`` issues a
+    # load/store to an unmapped global address. Propagating Int64 through
+    # ``flat_state_idx`` / ``flat_write_state_idx`` (computed below) keeps
+    # the offset multiplication 64-bit. See
+    # ``tests/gdn/test_decode_pretranspose_noncontiguous_pool.py
+    # ::test_decode_pretranspose_pool_int64_offset_bf16`` for the
+    # regression test.
+    cache_idx = cutlass.Int64(h0_indices[i_n])
     if cutlass.const_expr(same_pool):
         # Single-pool: alias write to read; nvcc DCEs the write-side LDG /
         # IMAD / local_tile entirely in this compile path.
         write_cache_idx = cache_idx
     else:
-        write_cache_idx = h0_out_indices[i_n]
+        write_cache_idx = cutlass.Int64(h0_out_indices[i_n])
         if write_cache_idx < 0:
-            write_cache_idx = cutlass.Int32(0)
+            write_cache_idx = cutlass.Int64(0)

     r_A_log = cutlass.Float32(A_log[i_hv])
     r_dt_bias = cutlass.Float32(dt_bias[i_hv])
@@ -225,7 +238,7 @@ def gdn_decode_bf16state_mtp_ilp4_kernel(
     )

     if cache_idx < 0:
-        cache_idx = cutlass.Int32(0)
+        cache_idx = cutlass.Int64(0)

     if cache_idx >= 0:
         k_start = lane_in_group * vec_size
@@ -652,7 +665,16 @@ def gdn_decode_bf16state_mtp_ilp4_kernel(
     # initial_state_indices points at slots >= B (i.e. any
     # realistic pool_size > B serving config). Fix mirrors
     # upstream PR #3145.
-    flat_idx = i_n * T * HV + i_t * HV + i_hv
+    # Defense-in-depth: widen to Int64 so the offset
+    # ``flat_idx * stride[0]`` (= ``flat_idx * V * K``
+    # = ``flat_idx * 16384`` BF16 elements) into the
+    # batch-scoped intermediate-states cache cannot
+    # wrap. This kernel is only reached at
+    # ``B * HV <= 128`` so the flat_idx itself stays
+    # well below the wrap threshold, but matching the
+    # wide_vec kernel below keeps the two paths
+    # bit-equivalent at large pool sizes.
+    flat_idx = cutlass.Int64(i_n) * T * HV + i_t * HV + i_hv
     ita = cute.local_tile(
         intermediate_states,
         (1, 1, vec_size),
@@ -780,7 +802,18 @@ def gdn_wide_vec_kernel(
     i_n = tmp // HV
     i_h = i_hv // (HV // H)

-    cache_idx = h0_indices[i_n]
+    # Widen pool indices to Int64 BEFORE they multiply ``stride[0]`` of
+    # ``h0_source``. ``h0_source`` is reshaped to ``[pool_size * HV, V,
+    # K]`` (BF16), so ``stride[0] = V * K = 16384`` elements; the
+    # downstream offset ``(cache_idx * HV + i_hv) * 16384`` wraps int32
+    # at ``cache_idx >= ceil(2**31 / (HV * V * K)) = 4096`` (HV=32) /
+    # 2048 (HV=64). Propagating Int64 through ``flat_state_idx`` /
+    # ``flat_write_state_idx`` keeps the ``cute.local_tile`` offset
+    # arithmetic 64-bit at every reachable pool size. See
+    # ``tests/gdn/test_decode_pretranspose_noncontiguous_pool.py
+    # ::test_decode_pretranspose_pool_int64_offset_bf16`` for the
+    # regression test.
+    cache_idx = cutlass.Int64(h0_indices[i_n])

     r_A_log = cutlass.Float32(A_log[i_hv])
     r_dt_bias = cutlass.Float32(dt_bias[i_hv])
@@ -824,7 +857,7 @@ def gdn_wide_vec_kernel(
     )

     if cache_idx < 0:
-        cache_idx = cutlass.Int32(0)
+        cache_idx = cutlass.Int64(0)

     # Split-pool write index: distinct slot to write the updated H state.
     # When same_pool=True (compile-time, set by the dispatcher whenever the
@@ -835,9 +868,9 @@ def gdn_wide_vec_kernel(
     if cutlass.const_expr(same_pool):
         write_cache_idx = cache_idx
     else:
-        write_cache_idx = h0_out_indices[i_n]
+        write_cache_idx = cutlass.Int64(h0_out_indices[i_n])
         if write_cache_idx < 0:
-            write_cache_idx = cutlass.Int32(0)
+            write_cache_idx = cutlass.Int64(0)

     if cache_idx >= 0:
         flat_state_idx = cache_idx * HV + i_hv
@@ -1169,7 +1202,18 @@ def gdn_wide_vec_kernel(
     # initial_state_indices points at slots >= B (i.e. any
     # realistic pool_size > B serving config). Fix mirrors
     # upstream PR #3145.
-    flat_idx = i_n * T * HV + i_t * HV + i_hv
+    # Widen to Int64: ``intermediate_states`` is
+    # reshaped to ``[B * T * HV, V, K]`` (BF16) with
+    # ``stride[0] = V * K = 16384`` elements. The
+    # offset ``flat_idx * 16384`` reaches 2**31 at
+    # ``flat_idx >= 131072``; with HV=64 + T=8 that's
+    # already hit at ``i_n >= 256`` (i.e. any
+    # production-scale MTP decode batch with caching
+    # enabled). Without the widening the Int32
+    # multiplication wraps and the
+    # ``cute.local_tile(intermediate_states, ...)``
+    # writes corrupt unrelated GMEM.
+    flat_idx = cutlass.Int64(i_n) * T * HV + i_t * HV + i_hv
     it0 = cute.local_tile(
         intermediate_states,
         (1, 1, vec),
```
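
The thresholds quoted in the bf16-state comments above can be re-derived in a few lines; this is a standalone arithmetic check, not kernel code:

```python
import math

V, K = 128, 128
stride0 = V * K                                 # 16384 elements per (slot, head) plane

# Pool-state offset (cache_idx * HV + i_hv) * stride0 first reaches 2**31 at:
for HV in (32, 64):
    print(HV, math.ceil(2**31 / (HV * V * K)))  # -> 4096 (HV=32), 2048 (HV=64)

# Intermediate-states cache: flat_idx = i_n * T * HV + i_t * HV + i_hv, and the
# offset flat_idx * stride0 wraps once flat_idx reaches 2**31 / 16384 = 131072.
T, HV = 8, 64
print((2**31 // stride0) // (T * HV))           # -> 256, the smallest wrapping i_n
```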

flashinfer/gdn_kernels/gdn_decode_pretranspose.py (21 additions & 6 deletions)
```diff
@@ -134,11 +134,20 @@ def gdn_decode_kernel_small_batch_pretranspose(

     # Compute state index: use pool indexing if enabled.
     if cutlass.const_expr(use_pool_indexing):
-        pool_idx = h0_indices[i_n]
-        out_pool_idx = h0_out_indices[i_n]
+        # Widen pool indices to Int64 BEFORE they multiply ``stride[0]`` of
+        # ``h0_source``. With Int32 indices, the per-slot element offset
+        # ``pool_idx * stride[0]`` silently wraps once it exceeds INT32_MAX,
+        # which makes the kernel issue loads/stores to unmapped global
+        # addresses (illegal memory access). For example, the padded slot
+        # stride 540672 used by vLLM for Qwen3.5-class GDN models crosses the
+        # threshold at pool_idx >= ceil(2**31 / 540672) = 3972. See
+        # ``tests/gdn/test_decode_pretranspose_noncontiguous_pool.py::
+        # test_decode_pretranspose_pool_int64_offset`` for the regression test.
+        pool_idx = cutlass.Int64(h0_indices[i_n])
+        out_pool_idx = cutlass.Int64(h0_out_indices[i_n])
         # Redirect negative write indices to null buffer (slot 0)
         if out_pool_idx < 0:
-            out_pool_idx = cutlass.Int32(0)
+            out_pool_idx = cutlass.Int64(0)
     else:
         pool_idx = 0
         out_pool_idx = 0
@@ -442,11 +451,17 @@ def gdn_decode_kernel_big_batch_pretranspose(

     # Compute state index: use pool indexing if enabled.
     if cutlass.const_expr(use_pool_indexing):
-        pool_idx = h0_indices[i_n]
-        out_pool_idx = h0_out_indices[i_n]
+        # Widen pool indices to Int64 BEFORE they multiply ``stride[0]`` of
+        # ``h0_source``. With Int32 indices, the per-slot element offset
+        # ``pool_idx * stride[0]`` silently wraps when it exceeds 2**31, which
+        # makes the kernel issue loads/stores to unmapped global addresses
+        # (illegal memory access). See the small-batch kernel above for the
+        # full rationale.
+        pool_idx = cutlass.Int64(h0_indices[i_n])
+        out_pool_idx = cutlass.Int64(h0_out_indices[i_n])
         # Redirect negative write indices to null buffer (slot 0)
         if out_pool_idx < 0:
-            out_pool_idx = cutlass.Int32(0)
+            out_pool_idx = cutlass.Int64(0)
     else:
         pool_idx = 0
         out_pool_idx = 0
```
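
The VRAM figures quoted for the new fp32 regression tests (~8.6 / 17.7 GB) are consistent with the padded slot stride in the comment above, assuming the test pool must hold every slot up to the parametrized maximum index (an interpretation of the test names, not taken from the test source):

```python
# Rough fp32 pool footprint for the parametrized max pool indices (4 bytes/element),
# using the 540672-element padded slot stride from the kernel comment above.
slot_elems = 540672
for max_idx in (3972, 8191):
    print(max_idx, round((max_idx + 1) * slot_elems * 4 / 1e9, 1))  # -> 8.6, 17.7 GB
```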
