Commit bc39fbc
authored
fix(gdn_decode): widen pool indices to Int64 to prevent int32 element-offset overflow (#3230)
## 📌 Description
Fix `CUDA error: an illegal memory access was encountered` in
`flashinfer.gdn_decode.gated_delta_rule_decode_pretranspose` when the
pool+indices API is used with sufficiently large pool indices.
**Root cause.** The CuTe-DSL kernels compute the per-slot element offset
(`pool_idx * stride[0]`, or `(cache_idx * HV + i_hv) * stride[0]` for
bf16) using **Int32** arithmetic. Once the product exceeds `INT32_MAX`,
it wraps to a negative offset and the load/store hits an unmapped global
address.
Affects both backends the API can dispatch to (HV=32, V=K=128):
| backend | kernel | overflow threshold |
|---|---|---|
| fp32 pretranspose | `gdn_decode_kernel_{small,big}_batch_pretranspose`
| `pool_idx >= 3972` (vLLM padded slot stride 540 672) |
| bf16 fast path | `gdn_decode_bf16state_mtp_kernel` | `cache_idx >=
4096` (contiguous, `stride[0] = HV*V*K = 524 288`) |
Discovered while integrating the kernel into vLLM's GDN decode path for
**Qwen3.5-class models**.
**Fix.** Widen the pool indices to Int64 immediately after they are
read; downstream offsets in `cute.local_tile(...)` / `h0_source[(...)]`
then promote to Int64 and cannot wrap:
```python
# fp32 pretranspose (small + big batch)
pool_idx = cutlass.Int64(h0_indices[i_n])
out_pool_idx = cutlass.Int64(h0_out_indices[i_n])
# bf16 MTP — propagates Int64 through flat_state_idx,
# flat_write_idx, and the intermediate-states cache's flat_idx.
cache_idx = cutlass.Int64(h0_indices[i_n])
write_cache_idx = cutlass.Int64(h0_out_indices[i_n])
```
## 🔍 Related Issues
Same class of bug as
[#3005](#3005) /
[#3007](#3007) (rmsnorm
stride overflow), in a different family of CuTe-DSL kernels.
## 🚀 Pull Request Checklist
### ✅ Pre-commit Checks
- [x] `pre-commit` installed and hooks installed.
- [x] `pre-commit run --files <changed files>` — all hooks pass.
## 🧪 Tests
- [x] Tests added.
- [x] All tests pass.
Added `tests/gdn/test_decode_pretranspose_noncontiguous_pool.py`:
- `test_decode_pretranspose_pool_int64_offset[3972, 8191]` — fp32
vLLM-padded pool (~8.6 / 17.7 GB).
- `test_decode_pretranspose_pool_int64_offset_bf16[4096, 4196]` — bf16
contiguous pool (~4.3 GB).
Both compare the pool path against a gather + direct-state reference
(numerical correctness, not just non-crashing) and assert the in-place
state update matches. VRAM-based skip when free memory is insufficient.
**Verified on NVIDIA B200 (SM100):** all 4 new tests crash without the
fix and pass with it; existing `pretranspose` and `bf16_state` tests in
`tests/gdn/test_decode_delta_rule.py` continue to pass.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **Bug Fixes**
* Resolved integer overflow in GPU decode kernels by switching pool- and
state-index arithmetic to 64-bit, preventing wraparound and
out-of-bounds addressing for large pools and batches.
* Ensured consistent 64-bit handling across all decode paths and
negative-index clamping.
* **Tests**
* Added GPU regression tests covering large-pool overflow scenarios for
FP32 and BF16 fast paths, with device-capacity guards to avoid OOM on
low-VRAM systems.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>1 parent 4381afc commit bc39fbc
3 files changed
Lines changed: 403 additions & 16 deletions
File tree
- flashinfer/gdn_kernels
- tests/gdn
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
163 | 163 | | |
164 | 164 | | |
165 | 165 | | |
166 | | - | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
167 | 180 | | |
168 | 181 | | |
169 | 182 | | |
170 | 183 | | |
171 | 184 | | |
172 | | - | |
| 185 | + | |
173 | 186 | | |
174 | | - | |
| 187 | + | |
175 | 188 | | |
176 | 189 | | |
177 | 190 | | |
| |||
225 | 238 | | |
226 | 239 | | |
227 | 240 | | |
228 | | - | |
| 241 | + | |
229 | 242 | | |
230 | 243 | | |
231 | 244 | | |
| |||
652 | 665 | | |
653 | 666 | | |
654 | 667 | | |
655 | | - | |
| 668 | + | |
| 669 | + | |
| 670 | + | |
| 671 | + | |
| 672 | + | |
| 673 | + | |
| 674 | + | |
| 675 | + | |
| 676 | + | |
| 677 | + | |
656 | 678 | | |
657 | 679 | | |
658 | 680 | | |
| |||
780 | 802 | | |
781 | 803 | | |
782 | 804 | | |
783 | | - | |
| 805 | + | |
| 806 | + | |
| 807 | + | |
| 808 | + | |
| 809 | + | |
| 810 | + | |
| 811 | + | |
| 812 | + | |
| 813 | + | |
| 814 | + | |
| 815 | + | |
| 816 | + | |
784 | 817 | | |
785 | 818 | | |
786 | 819 | | |
| |||
824 | 857 | | |
825 | 858 | | |
826 | 859 | | |
827 | | - | |
| 860 | + | |
828 | 861 | | |
829 | 862 | | |
830 | 863 | | |
| |||
835 | 868 | | |
836 | 869 | | |
837 | 870 | | |
838 | | - | |
| 871 | + | |
839 | 872 | | |
840 | | - | |
| 873 | + | |
841 | 874 | | |
842 | 875 | | |
843 | 876 | | |
| |||
1169 | 1202 | | |
1170 | 1203 | | |
1171 | 1204 | | |
1172 | | - | |
| 1205 | + | |
| 1206 | + | |
| 1207 | + | |
| 1208 | + | |
| 1209 | + | |
| 1210 | + | |
| 1211 | + | |
| 1212 | + | |
| 1213 | + | |
| 1214 | + | |
| 1215 | + | |
| 1216 | + | |
1173 | 1217 | | |
1174 | 1218 | | |
1175 | 1219 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
134 | 134 | | |
135 | 135 | | |
136 | 136 | | |
137 | | - | |
138 | | - | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
| 145 | + | |
| 146 | + | |
| 147 | + | |
139 | 148 | | |
140 | 149 | | |
141 | | - | |
| 150 | + | |
142 | 151 | | |
143 | 152 | | |
144 | 153 | | |
| |||
442 | 451 | | |
443 | 452 | | |
444 | 453 | | |
445 | | - | |
446 | | - | |
| 454 | + | |
| 455 | + | |
| 456 | + | |
| 457 | + | |
| 458 | + | |
| 459 | + | |
| 460 | + | |
| 461 | + | |
447 | 462 | | |
448 | 463 | | |
449 | | - | |
| 464 | + | |
450 | 465 | | |
451 | 466 | | |
452 | 467 | | |
| |||
0 commit comments