
Commit 4dba29f

blake-snc and claude authored
feat: add SM120 fmha_v2 kernels to AOT pip wheel builds (#2885)
## Summary

`gen_trtllm_fmha_v2_sm120_module()` exists in `jit/attention/modules.py`, and the JIT runtime path (`generate_kernels.py`) already dispatches to it correctly. However, `aot.py`'s `gen_all_modules()`, which drives the pip wheel AOT build, was missing it from the `has_sm120 or has_sm121` section. As a result, SM120/SM121 devices using a pip wheel never got the fmha_v2 SM120 kernels compiled into the wheel and had to fall back to slower paths.

**Fix:** Add `gen_trtllm_fmha_v2_sm120_module()` to the `has_sm120 or has_sm121` block in `aot.py`, alongside the other SM120 modules (fused MOE, GEMM, FP4 quantization). No behavior change for JIT users; only AOT pip wheel builds are affected.

Addresses the AOT gap noted in #2555.

Contributed by Second Nature Computing (https://joinsecondnature.com)

## Summary by CodeRabbit

* **Chores**
  * Expanded optimized inference module support for SM120 and SM121 GPUs to include attention kernels in addition to the existing fused MoE and GEMM optimizations.
  * Increased runtime coverage and readiness for attention-heavy workloads on those architectures, improving performance consistency for models using attention.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
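For context, a minimal sketch (illustrative only, not code from this patch) of how the `has_sm120` / `has_sm121` flags that gate the AOT module list correspond to a device's compute capability; the detection logic here is an assumption for demonstration, using PyTorch's standard capability query:

```python
# Illustrative only: derive SM120/SM121 flags from the local GPU's
# compute capability, mirroring the has_sm120 / has_sm121 flags that
# gate the AOT module list in flashinfer/aot.py.
import torch

major, minor = torch.cuda.get_device_capability()  # e.g. (12, 0) on SM120
has_sm120 = (major, minor) == (12, 0)
has_sm121 = (major, minor) == (12, 1)

if has_sm120 or has_sm121:
    # With this patch, an AOT pip wheel built for these targets now also
    # bundles the fmha_v2 SM120 attention kernels.
    print("fmha_v2 SM120 kernels should be included in the AOT wheel")
```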
1 parent 6885e76 · commit 4dba29f

1 file changed: flashinfer/aot.py (3 additions, 1 deletion)
```diff
@@ -44,6 +44,7 @@
     gen_single_decode_module,
     gen_single_prefill_module,
     gen_trtllm_gen_fmha_module,
+    gen_trtllm_fmha_v2_sm120_module,
 )
 from .jit.cascade import gen_cascade_module
 from .jit.cpp_ext import get_cuda_version
@@ -533,13 +534,14 @@ def gen_all_modules(
     if has_sm121:
         jit_specs.append(gen_fp4_quantization_sm121_module())
     if has_sm120 or has_sm121:
-        # SM120 and SM121 share the same CUTLASS kernels for fused MOE and GEMM.
+        # SM120 and SM121 share the same kernels for fused MOE, GEMM, and attention.
         # The SM120 module generators use supported_major_versions=[12] which
         # compiles for all SM12x targets.
         jit_specs.append(gen_cutlass_fused_moe_sm120_module())
         jit_specs.append(gen_gemm_sm120_module())
         jit_specs.append(gen_gemm_sm120_module_cutlass_fp4())
         jit_specs.append(gen_gemm_sm120_module_cutlass_mxfp8())
+        jit_specs.append(gen_trtllm_fmha_v2_sm120_module())
     if has_sm120f:
         jit_specs.append(gen_fp4_quantization_sm120f_module())
```
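The comment in the diff notes that `supported_major_versions=[12]` makes the SM120 module generators compile for all SM12x targets. A hypothetical sketch of how such a setting can expand into per-architecture nvcc flags; the helper name and the minor-version list are illustrative assumptions, not flashinfer's actual implementation:

```python
# Hypothetical helper, not flashinfer's code: expand a list of supported
# major compute-capability versions into nvcc -gencode flags covering
# each minor revision of interest (here SM120 and SM121).
def gencode_flags(supported_major_versions, minors=(0, 1)):
    flags = []
    for major in supported_major_versions:
        for minor in minors:
            arch = f"{major}{minor}"  # e.g. "120", "121"
            flags.append(f"-gencode=arch=compute_{arch},code=sm_{arch}")
    return flags

print(gencode_flags([12]))
# ['-gencode=arch=compute_120,code=sm_120',
#  '-gencode=arch=compute_121,code=sm_121']
```

Under this reading, adding `gen_trtllm_fmha_v2_sm120_module()` to the shared `has_sm120 or has_sm121` block is enough to cover both architectures with a single module entry.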
