From 6aaab0ab462ec47b46b12a1ae779a13f988b28fb Mon Sep 17 00:00:00 2001
From: Artem Perevedentsev
Date: Thu, 23 Apr 2026 18:24:28 +0300
Subject: [PATCH 1/5] [GDN] Enable FI Blackwell GDN prefill kernel

Signed-off-by: Artem Perevedentsev
---
 docker/Dockerfile                   |  17 +++
 setup.py                            |  10 ++
 .../layers/mamba/gdn_linear_attn.py | 111 +++++++++++++-----
 3 files changed, 110 insertions(+), 28 deletions(-)

diff --git a/docker/Dockerfile b/docker/Dockerfile
index d76a2e986b7c..ef184650a9ef 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -176,6 +176,15 @@ RUN --mount=type=cache,target=/root/.cache/uv \
         --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'); \
     fi
 
+# On CUDA 13 builds, install `nvidia-cutlass-dsl[cu13]` — it ships the
+# CuTe-DSL libs that FlashInfer's Blackwell SM100 GDN prefill kernel
+# (flashinfer-ai/flashinfer#3001) needs at JIT-compile time.
+RUN --mount=type=cache,target=/root/.cache/uv \
+    if [ "${CUDA_VERSION%%.*}" = "13" ]; then \
+        uv pip install --python /opt/venv/bin/python3 \
+            "nvidia-cutlass-dsl[cu13]>=4.4.2"; \
+    fi
+
 # Track PyTorch lib versions used during build and match in downstream instances.
 # We do this for both nightly and release so we can strip dependencies/*.txt as needed.
 # Otherwise library dependencies can upgrade/downgrade torch incorrectly.
@@ -577,6 +586,14 @@ RUN --mount=type=cache,target=/root/.cache/uv \
         --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') && \
     rm /tmp/requirements-cuda.txt /tmp/common.txt
 
+# On CUDA 13 builds, install `nvidia-cutlass-dsl[cu13]` — it ships the
+# CuTe-DSL libs that FlashInfer's Blackwell SM100 GDN prefill kernel
+# (flashinfer-ai/flashinfer#3001) needs at JIT-compile time.
+RUN --mount=type=cache,target=/root/.cache/uv \
+    if [ "${CUDA_VERSION%%.*}" = "13" ]; then \
+        uv pip install --system "nvidia-cutlass-dsl[cu13]>=4.4.2"; \
+    fi
+
 # Install FlashInfer JIT cache (requires CUDA-version-specific index URL)
 # https://docs.flashinfer.ai/installation.html
 # From versions.json: .flashinfer.version
diff --git a/setup.py b/setup.py
index c05280e40e78..dee26159a270 100644
--- a/setup.py
+++ b/setup.py
@@ -967,6 +967,16 @@ def _read_requirements(filename: str) -> list[str]:
                 # vllm-flash-attn is built only for CUDA 12.x.
                 # Skip for other versions.
                 continue
+            if req.startswith("nvidia-cutlass-dsl") and cuda_major == "13":
+                # On CUDA 13 builds, pull in the `[cu13]` extra — it ships
+                # the CuTe-DSL libs that FlashInfer's Blackwell SM100 GDN
+                # prefill kernel (flashinfer-ai/flashinfer#3001) needs at
+                # JIT-compile time.
+                req = req.replace(
+                    "nvidia-cutlass-dsl",
+                    "nvidia-cutlass-dsl[cu13]",
+                    1,
+                )
             modified_requirements.append(req)
         requirements = modified_requirements
     elif _is_hip():
diff --git a/vllm/model_executor/layers/mamba/gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
index 7a0b54335baa..982f034022f4 100644
--- a/vllm/model_executor/layers/mamba/gdn_linear_attn.py
+++ b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
@@ -67,6 +67,85 @@
 logger = init_logger(__name__)
 
 
+def _has_cutlass_dsl_cu13() -> bool:
+    """Whether the CUDA-13 CuTe-DSL shared libs are installed.
+    """
+    try:
+        from importlib.metadata import distribution
+    except ImportError:
+        return False
+    try:
+        distribution("nvidia-cutlass-dsl-libs-cu13")
+    except Exception:
+        return False
+    return True
+
+
+def _should_use_flashinfer_gdn_prefill(
+    backend: str, head_k_dim: int | None
+) -> bool:
+    """Whether to use FlashInfer's GDN prefill kernel instead of the
+    Triton/FLA fallback.
+
+    Requirements:
+    * ``requested in ["flashinfer", "auto"]``;
+    * ``platform == cuda``;
+    * one of the following:
+      - Hopper (SM90) — no further constraints;
+      - Blackwell (SM10.x) with ``head_k_dim == 128``,
+        ``nvidia-cutlass-dsl-libs-cu13`` installed, ``cuda_runtime >= 13``.
+    """
+    if backend not in ["flashinfer", "auto"]:
+        return False
+    if not current_platform.is_cuda():
+        return False
+    if current_platform.is_device_capability(90):
+        return True  # Hopper — no further constraints.
+    if not current_platform.is_device_capability_family(100):
+        return False  # Neither Hopper nor Blackwell.
+    if head_k_dim != 128:
+        return False
+    if not _has_cutlass_dsl_cu13():
+        return False
+    return True
+
+
+def _log_gdn_backend_decision(
+    backend: str, head_k_dim: int | None, use_flashinfer: bool
+) -> None:
+    """Dump the inputs to the backend decision and the final choice."""
+    is_cuda = current_platform.is_cuda()
+    platform = "cuda" if is_cuda else current_platform.device_name
+    cuda_runtime = torch.version.cuda or "n/a"
+    device_cap = (
+        str(current_platform.get_device_capability()) if is_cuda else "n/a"
+    )
+    cutlass_dsl_cu13_installed = _has_cutlass_dsl_cu13()
+    logger.info_once(
+        "GDN prefill backend inputs:\n"
+        "  requested=%s\n"
+        "  platform=%s, cuda_runtime=%s, device_capability=%s\n"
+        "  nvidia_cutlass_dsl_libs_cu13_installed=%s\n"
+        "  head_k_dim=%s",
+        backend,
+        platform,
+        cuda_runtime,
+        device_cap,
+        cutlass_dsl_cu13_installed,
+        head_k_dim,
+        scope="local",
+    )
+    if use_flashinfer:
+        logger.info_once("Using FlashInfer GDN prefill kernel")
+        logger.info_once(
+            "FlashInfer GDN prefill kernel is JIT-compiled; first run may "
+            "take a while to compile. Set `--gdn-prefill-backend triton` to "
+            "avoid JIT compile time.",
+        )
+    else:
+        logger.info_once("Using Triton/FLA GDN prefill kernel")
+
+
 def fi_chunk_gated_delta_rule(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -118,39 +197,15 @@ def fi_chunk_gated_delta_rule(
 
 @CustomOp.register("chunk_gated_delta_rule")
 class ChunkGatedDeltaRule(CustomOp):
-    def __init__(self) -> None:
+    def __init__(self, head_k_dim: int | None = None) -> None:
         super().__init__()
         additional_config = get_current_vllm_config().additional_config
         assert isinstance(additional_config, dict)
         backend_cfg = additional_config.get("gdn_prefill_backend", "auto")
         backend = str(backend_cfg).strip().lower()
 
-        supports_flashinfer = (
-            current_platform.is_cuda() and current_platform.is_device_capability(90)
-        )
-
-        if backend == "flashinfer":
-            use_flashinfer = supports_flashinfer
-            if not use_flashinfer:
-                logger.warning_once(
-                    "GDN prefill backend 'flashinfer' is selected but "
-                    "cannot use this kernel on the current platform. "
-                    "Falling back to Triton/FLA."
-                )
-        elif backend == "triton":
-            use_flashinfer = False
-        else:
-            use_flashinfer = supports_flashinfer
-
-        if use_flashinfer:
-            logger.info_once("Using FlashInfer GDN prefill kernel")
-            logger.info_once(
-                "FlashInfer GDN prefill kernel is JIT-compiled; first run may "
-                "take a while to compile. Set `--gdn-prefill-backend triton` to "
-                "avoid JIT compile time.",
-            )
-        else:
-            logger.info_once("Using Triton/FLA GDN prefill kernel")
+        use_flashinfer = _should_use_flashinfer_gdn_prefill(backend, head_k_dim)
+        _log_gdn_backend_decision(backend, head_k_dim, use_flashinfer)
 
         self._forward_method = (
             self.forward_cuda if use_flashinfer else self.forward_native
@@ -380,7 +435,7 @@ def __init__(
             prefix=f"{prefix}.out_proj",
         )
 
-        self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
+        self.chunk_gated_delta_rule = ChunkGatedDeltaRule(head_k_dim=self.head_k_dim)
         self.enable_packed_recurrent_decode = (
            envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE
        )
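The decision table patch 1 encodes is compact enough to sanity-check in isolation. The following standalone sketch mirrors its logic with plain values standing in for the `current_platform` queries and the installed-package probe; the function and parameter names here are illustrative, not vLLM API:

    # Standalone sketch of the patch-1 selection logic (assumptions, not
    # the vLLM code itself): booleans replace the platform queries.
    def should_use_flashinfer_gdn_prefill(
        backend: str,
        is_cuda: bool,
        capability_major: int,  # 9 for Hopper (SM90), 10 for Blackwell (SM10.x)
        head_k_dim: int | None,
        has_cutlass_dsl_cu13: bool,
    ) -> bool:
        if backend not in ("flashinfer", "auto"):
            return False
        if not is_cuda:
            return False
        if capability_major == 9:
            return True  # Hopper: no further constraints
        if capability_major != 10:
            return False  # neither Hopper nor Blackwell
        # Blackwell: head_k_dim and the CUDA-13 CuTe-DSL libs both gate the kernel.
        return head_k_dim == 128 and has_cutlass_dsl_cu13

    assert should_use_flashinfer_gdn_prefill("auto", True, 9, None, False)
    assert should_use_flashinfer_gdn_prefill("auto", True, 10, 128, True)
    assert not should_use_flashinfer_gdn_prefill("auto", True, 10, 128, False)
    assert not should_use_flashinfer_gdn_prefill("triton", True, 9, 128, True)

Note that `triton` never reaches the capability checks, so an explicit `--gdn-prefill-backend triton` always lands on the Triton/FLA path.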
From 560797f0202efd1a301726044a581b07332c5e07 Mon Sep 17 00:00:00 2001
From: Artem Perevedentsev
Date: Thu, 23 Apr 2026 21:00:49 +0300
Subject: [PATCH 2/5] [GDN] Address review feedback from Gemini

Signed-off-by: Artem Perevedentsev
---
 .../layers/mamba/gdn_linear_attn.py | 26 +++++-------
 vllm/platforms/interface.py         | 19 ++++++++++++++
 2 files changed, 29 insertions(+), 16 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
index 982f034022f4..664c1a11925d 100644
--- a/vllm/model_executor/layers/mamba/gdn_linear_attn.py
+++ b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
@@ -67,20 +67,6 @@
 logger = init_logger(__name__)
 
 
-def _has_cutlass_dsl_cu13() -> bool:
-    """Whether the CUDA-13 CuTe-DSL shared libs are installed.
-    """
-    try:
-        from importlib.metadata import distribution
-    except ImportError:
-        return False
-    try:
-        distribution("nvidia-cutlass-dsl-libs-cu13")
-    except Exception:
-        return False
-    return True
-
-
 def _should_use_flashinfer_gdn_prefill(
     backend: str, head_k_dim: int | None
 ) -> bool:
@@ -105,7 +91,9 @@ def _should_use_flashinfer_gdn_prefill(
         return False  # Neither Hopper nor Blackwell.
     if head_k_dim != 128:
         return False
-    if not _has_cutlass_dsl_cu13():
+    if current_platform.get_cuda_runtime_major() < 13:
+        return False
+    if not current_platform.has_cutlass_dsl_cu13():
         return False
     return True
 
@@ -120,7 +108,7 @@ def _log_gdn_backend_decision(
     device_cap = (
         str(current_platform.get_device_capability()) if is_cuda else "n/a"
     )
-    cutlass_dsl_cu13_installed = _has_cutlass_dsl_cu13()
+    cutlass_dsl_cu13_installed = current_platform.has_cutlass_dsl_cu13()
     logger.info_once(
         "GDN prefill backend inputs:\n"
         "  requested=%s\n"
@@ -205,6 +193,12 @@ def __init__(self, head_k_dim: int | None = None) -> None:
         backend = str(backend_cfg).strip().lower()
 
         use_flashinfer = _should_use_flashinfer_gdn_prefill(backend, head_k_dim)
+        if backend == "flashinfer" and not use_flashinfer:
+            logger.warning_once(
+                "GDN prefill backend 'flashinfer' is selected but "
+                "cannot use this kernel on the current platform. "
+                "Falling back to Triton/FLA."
+            )
         _log_gdn_backend_decision(backend, head_k_dim, use_flashinfer)
 
         self._forward_method = (
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index c0d52620c086..202b92344b4e 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -359,6 +359,25 @@ def is_device_capability_family(
             return False
         return (current_capability.to_int() // 10) == (capability // 10)
 
+    @classmethod
+    def get_cuda_runtime_major(cls) -> int:
+        """Major ``torch.version.cuda`` version, or ``0`` if undetermined."""
+        major = (torch.version.cuda or "0").split(".", 1)[0]
+        return int(major) if major.isdigit() else 0
+
+    @classmethod
+    def has_cutlass_dsl_cu13(cls) -> bool:
+        """Whether ``nvidia-cutlass-dsl-libs-cu13`` is installed."""
+        try:
+            from importlib.metadata import distribution
+        except ImportError:
+            return False
+        try:
+            distribution("nvidia-cutlass-dsl-libs-cu13")
+        except Exception:
+            return False
+        return True
+
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
         """Get the name of a device."""

From be4dfdb18a3dd8ee59ec5b378dd73d5e17629195 Mon Sep 17 00:00:00 2001
From: Artem Perevedentsev
Date: Thu, 23 Apr 2026 21:28:54 +0300
Subject: [PATCH 3/5] [GDN] Fix ruff SIM103 and format

Signed-off-by: Artem Perevedentsev
---
 vllm/model_executor/layers/mamba/gdn_linear_attn.py | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
index 664c1a11925d..60f8b4ea534e 100644
--- a/vllm/model_executor/layers/mamba/gdn_linear_attn.py
+++ b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
@@ -67,9 +67,7 @@
 logger = init_logger(__name__)
 
 
-def _should_use_flashinfer_gdn_prefill(
-    backend: str, head_k_dim: int | None
-) -> bool:
+def _should_use_flashinfer_gdn_prefill(backend: str, head_k_dim: int | None) -> bool:
     """Whether to use FlashInfer's GDN prefill kernel instead of the
     Triton/FLA fallback.
 
@@ -93,9 +91,7 @@ def _should_use_flashinfer_gdn_prefill(backend: str, head_k_dim: int | None) ->
         return False
     if current_platform.get_cuda_runtime_major() < 13:
         return False
-    if not current_platform.has_cutlass_dsl_cu13():
-        return False
-    return True
+    return current_platform.has_cutlass_dsl_cu13()
 
 
 def _log_gdn_backend_decision(
@@ -105,9 +101,7 @@ def _log_gdn_backend_decision(
     is_cuda = current_platform.is_cuda()
     platform = "cuda" if is_cuda else current_platform.device_name
     cuda_runtime = torch.version.cuda or "n/a"
-    device_cap = (
-        str(current_platform.get_device_capability()) if is_cuda else "n/a"
-    )
+    device_cap = str(current_platform.get_device_capability()) if is_cuda else "n/a"
     cutlass_dsl_cu13_installed = current_platform.has_cutlass_dsl_cu13()
     logger.info_once(
         "GDN prefill backend inputs:\n"
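The two platform probes patch 2 introduces are small enough to exercise standalone. Here is a minimal sketch of the same checks, with a plain string standing in for `torch.version.cuda`; the helper names are illustrative, not the vLLM `Platform` API:

    # Sketch under stated assumptions: version strings come in as "13.0",
    # None, or garbage; the package probe reads metadata only.
    from importlib.metadata import PackageNotFoundError, distribution

    def cuda_runtime_major(version: str | None) -> int:
        # Mirrors get_cuda_runtime_major: "13.0" -> 13, None/garbage -> 0.
        major = (version or "0").split(".", 1)[0]
        return int(major) if major.isdigit() else 0

    def dist_installed(name: str) -> bool:
        # Metadata-only probe; does not import the package itself.
        try:
            distribution(name)
        except PackageNotFoundError:
            return False
        return True

    assert cuda_runtime_major("13.0") == 13
    assert cuda_runtime_major(None) == 0
    assert cuda_runtime_major("not-a-version") == 0
    print(dist_installed("nvidia-cutlass-dsl-libs-cu13"))

The vLLM version catches the broader `Exception` (plus an `ImportError` on the `importlib.metadata` import itself), so a damaged metadata entry degrades to the Triton/FLA fallback instead of failing engine startup.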
From cb71f43420a40a072bfd631f84b650d416f13ff3 Mon Sep 17 00:00:00 2001
From: Artem Perevedentsev
Date: Mon, 4 May 2026 11:42:07 +0300
Subject: [PATCH 4/5] [GDN] Delegate cu13 deps to FlashInfer's [cu13] extra

Signed-off-by: Artem Perevedentsev
---
 docker/Dockerfile    | 19 ++++++++++---------
 docker/versions.json |  6 +++---
 setup.py             | 12 +++++-------
 3 files changed, 18 insertions(+), 19 deletions(-)

diff --git a/docker/Dockerfile b/docker/Dockerfile
index ef184650a9ef..5442f42b6a14 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -25,6 +25,7 @@
 ARG CUDA_VERSION=13.0.0
 ARG PYTHON_VERSION=3.12
 ARG UBUNTU_VERSION=22.04
+ARG FLASHINFER_VERSION=0.6.8.post1
 
 # By parameterizing the base images, we allow third-party to use their own
 # base images. One use case is hermetic builds with base images stored in
@@ -94,6 +95,7 @@ FROM ${BUILD_BASE_IMAGE} AS base
 
 ARG CUDA_VERSION
 ARG PYTHON_VERSION
+ARG FLASHINFER_VERSION
 
 ENV DEBIAN_FRONTEND=noninteractive
 
@@ -176,13 +178,12 @@ RUN --mount=type=cache,target=/root/.cache/uv \
         --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'); \
     fi
 
-# On CUDA 13 builds, install `nvidia-cutlass-dsl[cu13]` — it ships the
-# CuTe-DSL libs that FlashInfer's Blackwell SM100 GDN prefill kernel
-# (flashinfer-ai/flashinfer#3001) needs at JIT-compile time.
+# `flashinfer-python` is already installed via requirements/cuda.txt above;
+# this only activates its `[cu13]` extra (cu13 deps for the SM100 GDN kernel).
 RUN --mount=type=cache,target=/root/.cache/uv \
     if [ "${CUDA_VERSION%%.*}" = "13" ]; then \
         uv pip install --python /opt/venv/bin/python3 \
-            "nvidia-cutlass-dsl[cu13]>=4.4.2"; \
+            "flashinfer-python[cu13]==${FLASHINFER_VERSION}"; \
     fi
 
 # Track PyTorch lib versions used during build and match in downstream instances.
@@ -490,6 +491,7 @@ ARG PYTHON_VERSION
 ARG DEADSNAKES_MIRROR_URL
 ARG DEADSNAKES_GPGKEY_URL
 ARG GET_PIP_URL
+ARG FLASHINFER_VERSION
 
 ENV DEBIAN_FRONTEND=noninteractive
 WORKDIR /vllm-workspace
@@ -586,18 +588,17 @@ RUN --mount=type=cache,target=/root/.cache/uv \
         --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') && \
     rm /tmp/requirements-cuda.txt /tmp/common.txt
 
-# On CUDA 13 builds, install `nvidia-cutlass-dsl[cu13]` — it ships the
-# CuTe-DSL libs that FlashInfer's Blackwell SM100 GDN prefill kernel
-# (flashinfer-ai/flashinfer#3001) needs at JIT-compile time.
+# `flashinfer-python` is already installed via requirements/cuda.txt above;
+# this only activates its `[cu13]` extra (cu13 deps for the SM100 GDN kernel).
 RUN --mount=type=cache,target=/root/.cache/uv \
     if [ "${CUDA_VERSION%%.*}" = "13" ]; then \
-        uv pip install --system "nvidia-cutlass-dsl[cu13]>=4.4.2"; \
+        uv pip install --system \
+            "flashinfer-python[cu13]==${FLASHINFER_VERSION}"; \
     fi
 
 # Install FlashInfer JIT cache (requires CUDA-version-specific index URL)
 # https://docs.flashinfer.ai/installation.html
 # From versions.json: .flashinfer.version
-ARG FLASHINFER_VERSION=0.6.8.post1
 RUN --mount=type=cache,target=/root/.cache/uv \
     uv pip install --system flashinfer-jit-cache==${FLASHINFER_VERSION} \
         --extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
diff --git a/docker/versions.json b/docker/versions.json
index f4e05914afa0..fec5cba87f51 100644
--- a/docker/versions.json
+++ b/docker/versions.json
@@ -10,6 +10,9 @@
     "UBUNTU_VERSION": {
         "default": "22.04"
     },
+    "FLASHINFER_VERSION": {
+        "default": "0.6.8.post1"
+    },
     "BUILD_BASE_IMAGE": {
         "default": "nvidia/cuda:13.0.0-devel-ubuntu22.04"
     },
@@ -64,9 +67,6 @@
     "RUN_WHEEL_CHECK": {
         "default": "true"
     },
-    "FLASHINFER_VERSION": {
-        "default": "0.6.8.post1"
-    },
     "GDRCOPY_CUDA_VERSION": {
         "default": "12.8"
     },
diff --git a/setup.py b/setup.py
index dee26159a270..8cefa64c0c08 100644
--- a/setup.py
+++ b/setup.py
@@ -967,14 +967,12 @@ def _read_requirements(filename: str) -> list[str]:
                 # vllm-flash-attn is built only for CUDA 12.x.
                 # Skip for other versions.
                 continue
-            if req.startswith("nvidia-cutlass-dsl") and cuda_major == "13":
-                # On CUDA 13 builds, pull in the `[cu13]` extra — it ships
-                # the CuTe-DSL libs that FlashInfer's Blackwell SM100 GDN
-                # prefill kernel (flashinfer-ai/flashinfer#3001) needs at
-                # JIT-compile time.
+            if req.startswith("flashinfer-python") and cuda_major == "13":
+                # Activate FI's `[cu13]` extra on cu13 builds (cu13 deps for
+                # the SM100 GDN kernel). Mirrors the Dockerfile cu13 path.
                 req = req.replace(
-                    "nvidia-cutlass-dsl",
-                    "nvidia-cutlass-dsl[cu13]",
+                    "flashinfer-python",
+                    "flashinfer-python[cu13]",
                     1,
                 )
             modified_requirements.append(req)
         requirements = modified_requirements
     elif _is_hip():
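The requirement rewrite patch 4 installs in setup.py hinges on `str.replace(..., 1)` leaving any version specifier intact. A quick standalone check of that transformation (the helper name here is illustrative, not a setup.py function):

    # Sketch of the cu13 extra activation; assumes requirement strings of
    # the form "<name><specifier>", as in requirements/cuda.txt.
    def activate_cu13_extra(reqs: list[str], cuda_major: str) -> list[str]:
        out = []
        for req in reqs:
            if req.startswith("flashinfer-python") and cuda_major == "13":
                # count=1 rewrites only the distribution name, so pins such
                # as ==0.6.8.post1 survive untouched.
                req = req.replace("flashinfer-python", "flashinfer-python[cu13]", 1)
            out.append(req)
        return out

    print(activate_cu13_extra(["flashinfer-python==0.6.8.post1"], "13"))
    # ['flashinfer-python[cu13]==0.6.8.post1']
    print(activate_cu13_extra(["flashinfer-python==0.6.8.post1"], "12"))
    # ['flashinfer-python==0.6.8.post1']

On CUDA 12 builds the requirement passes through unchanged, matching the Dockerfile guard on `${CUDA_VERSION%%.*}`.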
From d79e7fef9c078ed67d87f0878636d0c565d45ef1 Mon Sep 17 00:00:00 2001
From: Artem Perevedentsev
Date: Tue, 5 May 2026 14:50:55 +0300
Subject: [PATCH 5/5] Remove redundant cu13 and has_cutlass_dsl_cu13 checks

Signed-off-by: Artem Perevedentsev
---
 .../layers/mamba/gdn_linear_attn.py | 12 ++----------
 vllm/platforms/interface.py         | 19 -------------------
 2 files changed, 2 insertions(+), 29 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
index c614f48b6a3c..13290b480a30 100644
--- a/vllm/model_executor/layers/mamba/gdn_linear_attn.py
+++ b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
@@ -76,8 +76,7 @@ def _should_use_flashinfer_gdn_prefill(backend: str, head_k_dim: int | None) ->
     * ``platform == cuda``;
     * one of the following:
       - Hopper (SM90) — no further constraints;
-      - Blackwell (SM10.x) with ``head_k_dim == 128``,
-        ``nvidia-cutlass-dsl-libs-cu13`` installed, ``cuda_runtime >= 13``.
+      - Blackwell (SM10.x) with ``head_k_dim == 128``.
     """
     if backend not in ["flashinfer", "auto"]:
         return False
@@ -87,11 +86,7 @@ def _should_use_flashinfer_gdn_prefill(backend: str, head_k_dim: int | None) ->
         return True  # Hopper — no further constraints.
     if not current_platform.is_device_capability_family(100):
         return False  # Neither Hopper nor Blackwell.
-    if head_k_dim != 128:
-        return False
-    if current_platform.get_cuda_runtime_major() < 13:
-        return False
-    return current_platform.has_cutlass_dsl_cu13()
+    return head_k_dim == 128
 
 
 def _log_gdn_backend_decision(
@@ -102,18 +97,15 @@ def _log_gdn_backend_decision(
     is_cuda = current_platform.is_cuda()
     platform = "cuda" if is_cuda else current_platform.device_name
     cuda_runtime = torch.version.cuda or "n/a"
     device_cap = str(current_platform.get_device_capability()) if is_cuda else "n/a"
-    cutlass_dsl_cu13_installed = current_platform.has_cutlass_dsl_cu13()
     logger.info_once(
         "GDN prefill backend inputs:\n"
         "  requested=%s\n"
         "  platform=%s, cuda_runtime=%s, device_capability=%s\n"
-        "  nvidia_cutlass_dsl_libs_cu13_installed=%s\n"
         "  head_k_dim=%s",
         backend,
         platform,
         cuda_runtime,
         device_cap,
-        cutlass_dsl_cu13_installed,
         head_k_dim,
         scope="local",
     )
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index cc41ec930b21..2753326755fb 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -362,25 +362,6 @@ def is_device_capability_family(
             return False
         return (current_capability.to_int() // 10) == (capability // 10)
 
-    @classmethod
-    def get_cuda_runtime_major(cls) -> int:
-        """Major ``torch.version.cuda`` version, or ``0`` if undetermined."""
-        major = (torch.version.cuda or "0").split(".", 1)[0]
-        return int(major) if major.isdigit() else 0
-
-    @classmethod
-    def has_cutlass_dsl_cu13(cls) -> bool:
-        """Whether ``nvidia-cutlass-dsl-libs-cu13`` is installed."""
-        try:
-            from importlib.metadata import distribution
-        except ImportError:
-            return False
-        try:
-            distribution("nvidia-cutlass-dsl-libs-cu13")
-        except Exception:
-            return False
-        return True
-
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
         """Get the name of a device."""
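Patch 5 can drop the runtime and package probes because the remaining Blackwell gate rests on `is_device_capability_family(100)`, whose arithmetic is visible in the context lines above: a capability `(major, minor)` packs to `major * 10 + minor`, and families compare by integer division. A standalone illustration of that arithmetic (helper names are illustrative, not the vLLM API):

    # Sketch of the SM-family matching behind is_device_capability_family(100).
    def capability_to_int(major: int, minor: int) -> int:
        # (10, 0) -> 100, i.e. SM100; (9, 0) -> 90, i.e. SM90.
        return major * 10 + minor

    def same_family(current: int, family: int) -> bool:
        # Compare major families only, mirroring the // 10 in the real check.
        return (current // 10) == (family // 10)

    assert same_family(capability_to_int(10, 0), 100)      # SM100 is Blackwell
    assert same_family(capability_to_int(10, 3), 100)      # so is SM103
    assert not same_family(capability_to_int(9, 0), 100)   # SM90 is Hopper

Every SM10.x part therefore matches family 100, which is exactly the device set the simplified `head_k_dim == 128` check is meant to cover.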