diff --git a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py b/flashinfer/gdn_kernels/blackwell/gdn_prefill.py index 82dcc72b07..aafcc6716a 100644 --- a/flashinfer/gdn_kernels/blackwell/gdn_prefill.py +++ b/flashinfer/gdn_kernels/blackwell/gdn_prefill.py @@ -33,9 +33,10 @@ import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute -import cutlass.utils as cutlass_utils from cutlass.cute.runtime import from_dlpack +from flashinfer.cute_dsl.utils import get_num_sm + from .gated_delta_net_chunked import GatedDeltaNetChunkedKernel @@ -157,9 +158,8 @@ def chunk_gated_delta_rule_sm100( if "compiled" not in cache: # --- First call: compile the kernel --- - hardware_info = cutlass_utils.HardwareInfo() - num_sm = hardware_info.get_max_active_clusters(1) - max_active_clusters = hardware_info.get_max_active_clusters(1) + num_sm = get_num_sm(q.device) + max_active_clusters = num_sm gdn = GatedDeltaNetChunkedKernel( io_dtype=io_dtype,