diff --git a/docker/Dockerfile b/docker/Dockerfile
index fd0622e2416a..f1dcad898438 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -25,6 +25,7 @@
 ARG CUDA_VERSION=13.0.2
 ARG PYTHON_VERSION=3.12
 ARG UBUNTU_VERSION=22.04
+ARG FLASHINFER_VERSION=0.6.8.post1
 
 # By parameterizing the base images, we allow third-party to use their own
 # base images. One use case is hermetic builds with base images stored in
@@ -101,6 +102,7 @@
 FROM ${BUILD_BASE_IMAGE} AS base
 ARG CUDA_VERSION
 ARG PYTHON_VERSION
+ARG FLASHINFER_VERSION
 ARG BUILD_OS
 
 ENV DEBIAN_FRONTEND=noninteractive
@@ -212,6 +214,14 @@
 RUN --mount=type=cache,target=/root/.cache/uv \
         --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'); \
     fi
+# `flashinfer-python` is already installed via requirements/cuda.txt above;
+# this only activates its `[cu13]` extra (cu13 deps for the SM100 GDN kernel).
+RUN --mount=type=cache,target=/root/.cache/uv \
+    if [ "${CUDA_VERSION%%.*}" = "13" ]; then \
+        uv pip install --python /opt/venv/bin/python3 \
+            "flashinfer-python[cu13]==${FLASHINFER_VERSION}"; \
+    fi
+
 # Track PyTorch lib versions used during build and match in downstream instances.
 # We do this for both nightly and release so we can strip dependencies/*.txt as needed.
 # Otherwise library dependencies can upgrade/downgrade torch incorrectly.
@@ -522,6 +532,7 @@ ARG PYTHON_VERSION
 ARG DEADSNAKES_MIRROR_URL
 ARG DEADSNAKES_GPGKEY_URL
 ARG GET_PIP_URL
+ARG FLASHINFER_VERSION
 
 ENV DEBIAN_FRONTEND=noninteractive
 WORKDIR /vllm-workspace
@@ -620,10 +631,17 @@
 RUN --mount=type=cache,target=/root/.cache/uv \
         --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') && \
     rm /tmp/requirements-cuda.txt /tmp/common.txt
+# `flashinfer-python` is already installed via requirements/cuda.txt above;
+# this only activates its `[cu13]` extra (cu13 deps for the SM100 GDN kernel).
+RUN --mount=type=cache,target=/root/.cache/uv \
+    if [ "${CUDA_VERSION%%.*}" = "13" ]; then \
+        uv pip install --system \
+            "flashinfer-python[cu13]==${FLASHINFER_VERSION}"; \
+    fi
+
 # Install FlashInfer JIT cache (requires CUDA-version-specific index URL)
 # https://docs.flashinfer.ai/installation.html
 # From versions.json: .flashinfer.version
-ARG FLASHINFER_VERSION=0.6.8.post1
 RUN --mount=type=cache,target=/root/.cache/uv \
     uv pip install --system flashinfer-jit-cache==${FLASHINFER_VERSION} \
         --extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
diff --git a/docker/versions.json b/docker/versions.json
index 75652823db0b..172536df54d2 100644
--- a/docker/versions.json
+++ b/docker/versions.json
@@ -10,6 +10,9 @@
     "UBUNTU_VERSION": {
         "default": "22.04"
     },
+    "FLASHINFER_VERSION": {
+        "default": "0.6.8.post1"
+    },
     "BUILD_BASE_IMAGE": {
         "default": "nvidia/cuda:13.0.2-devel-ubuntu22.04"
     },
@@ -67,9 +70,6 @@
     "RUN_WHEEL_CHECK": {
         "default": "true"
     },
-    "FLASHINFER_VERSION": {
-        "default": "0.6.8.post1"
-    },
     "GDRCOPY_CUDA_VERSION": {
         "default": "12.8"
     },
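A minimal sketch (not part of the patch) of how the pinned default above might be consumed, assuming only the `docker/versions.json` layout shown in this diff; the build invocation itself is illustrative, not vLLM's actual build tooling:

```python
# Illustrative: read pinned defaults from docker/versions.json and forward them
# to `docker build` as build args, matching the new top-level ARG FLASHINFER_VERSION.
import json
import shlex

with open("docker/versions.json") as f:
    versions = json.load(f)

cmd = [
    "docker", "build", "-f", "docker/Dockerfile",
    "--build-arg", f"FLASHINFER_VERSION={versions['FLASHINFER_VERSION']['default']}",
    "--build-arg", f"BUILD_BASE_IMAGE={versions['BUILD_BASE_IMAGE']['default']}",
    ".",
]
print(shlex.join(cmd))  # paste into a shell, or hand the list to subprocess.run
```

Hoisting `FLASHINFER_VERSION` to a top-level build arg keeps the `flashinfer-python[cu13]` install and the `flashinfer-jit-cache` install pinned to the same release.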
diff --git a/setup.py b/setup.py
index 7c226a72425f..4d790d289cf2 100644
--- a/setup.py
+++ b/setup.py
@@ -969,6 +969,14 @@ def _read_requirements(filename: str) -> list[str]:
                 # vllm-flash-attn is built only for CUDA 12.x.
                 # Skip for other versions.
                 continue
+            if req.startswith("flashinfer-python") and cuda_major == "13":
+                # Activate FI's `[cu13]` extra on cu13 builds (cu13 deps for
+                # the SM100 GDN kernel). Mirrors the Dockerfile cu13 path.
+                req = req.replace(
+                    "flashinfer-python",
+                    "flashinfer-python[cu13]",
+                    1,
+                )
             modified_requirements.append(req)
         requirements = modified_requirements
     elif _is_hip():
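A standalone restatement (not part of the patch) of the requirement rewrite above, using hypothetical requirement strings, to show that only the distribution name gains the extra and any version specifier is carried over:

```python
# Mirrors the setup.py logic above in isolation; the example strings are hypothetical.
def add_cu13_extra(req: str, cuda_major: str) -> str:
    if req.startswith("flashinfer-python") and cuda_major == "13":
        # Only the first occurrence of the name is rewritten, so a pin such as
        # `==0.6.8.post1` survives untouched.
        return req.replace("flashinfer-python", "flashinfer-python[cu13]", 1)
    return req


assert add_cu13_extra("flashinfer-python==0.6.8.post1", "13") == (
    "flashinfer-python[cu13]==0.6.8.post1"
)
assert add_cu13_extra("flashinfer-python==0.6.8.post1", "12") == (
    "flashinfer-python==0.6.8.post1"
)
```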
diff --git a/vllm/model_executor/layers/mamba/gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
index a621ab962f0a..13290b480a30 100644
--- a/vllm/model_executor/layers/mamba/gdn_linear_attn.py
+++ b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
@@ -67,6 +67,59 @@
 logger = init_logger(__name__)
 
 
+def _should_use_flashinfer_gdn_prefill(backend: str, head_k_dim: int | None) -> bool:
+    """Whether to use FlashInfer's GDN prefill kernel instead of the
+    Triton/FLA fallback.
+
+    Requirements:
+      * ``requested in ["flashinfer", "auto"]``;
+      * ``platform == cuda``;
+      * one of the following:
+        - Hopper (SM90) — no further constraints;
+        - Blackwell (SM10.x) with ``head_k_dim == 128``.
+    """
+    if backend not in ["flashinfer", "auto"]:
+        return False
+    if not current_platform.is_cuda():
+        return False
+    if current_platform.is_device_capability(90):
+        return True  # Hopper — no further constraints.
+    if not current_platform.is_device_capability_family(100):
+        return False  # Neither Hopper nor Blackwell.
+    return head_k_dim == 128
+
+
+def _log_gdn_backend_decision(
+    backend: str, head_k_dim: int | None, use_flashinfer: bool
+) -> None:
+    """Dump the inputs to the backend decision and the final choice."""
+    is_cuda = current_platform.is_cuda()
+    platform = "cuda" if is_cuda else current_platform.device_name
+    cuda_runtime = torch.version.cuda or "n/a"
+    device_cap = str(current_platform.get_device_capability()) if is_cuda else "n/a"
+    logger.info_once(
+        "GDN prefill backend inputs:\n"
+        "  requested=%s\n"
+        "  platform=%s, cuda_runtime=%s, device_capability=%s\n"
+        "  head_k_dim=%s",
+        backend,
+        platform,
+        cuda_runtime,
+        device_cap,
+        head_k_dim,
+        scope="local",
+    )
+    if use_flashinfer:
+        logger.info_once("Using FlashInfer GDN prefill kernel")
+        logger.info_once(
+            "FlashInfer GDN prefill kernel is JIT-compiled; first run may "
+            "take a while to compile. Set `--gdn-prefill-backend triton` to "
+            "avoid JIT compile time.",
+        )
+    else:
+        logger.info_once("Using Triton/FLA GDN prefill kernel")
+
+
 def fi_chunk_gated_delta_rule(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -118,39 +171,21 @@ def fi_chunk_gated_delta_rule(
 
 @CustomOp.register("chunk_gated_delta_rule")
 class ChunkGatedDeltaRule(CustomOp):
-    def __init__(self) -> None:
+    def __init__(self, head_k_dim: int | None = None) -> None:
         super().__init__()
         additional_config = get_current_vllm_config().additional_config
         assert isinstance(additional_config, dict)
         backend_cfg = additional_config.get("gdn_prefill_backend", "auto")
         backend = str(backend_cfg).strip().lower()
 
-        supports_flashinfer = (
-            current_platform.is_cuda() and current_platform.is_device_capability(90)
-        )
-
-        if backend == "flashinfer":
-            use_flashinfer = supports_flashinfer
-            if not use_flashinfer:
-                logger.warning_once(
-                    "GDN prefill backend 'flashinfer' is selected but "
-                    "cannot use this kernel on the current platform. "
-                    "Falling back to Triton/FLA."
-                )
-        elif backend == "triton":
-            use_flashinfer = False
-        else:
-            use_flashinfer = supports_flashinfer
-
-        if use_flashinfer:
-            logger.info_once("Using FlashInfer GDN prefill kernel")
-            logger.info_once(
-                "FlashInfer GDN prefill kernel is JIT-compiled; first run may "
-                "take a while to compile. Set `--gdn-prefill-backend triton` to "
-                "avoid JIT compile time.",
+        use_flashinfer = _should_use_flashinfer_gdn_prefill(backend, head_k_dim)
+        if backend == "flashinfer" and not use_flashinfer:
+            logger.warning_once(
+                "GDN prefill backend 'flashinfer' is selected but "
+                "cannot use this kernel on the current platform. "
+                "Falling back to Triton/FLA."
             )
-        else:
-            logger.info_once("Using Triton/FLA GDN prefill kernel")
+        _log_gdn_backend_decision(backend, head_k_dim, use_flashinfer)
 
         self._forward_method = (
             self.forward_cuda if use_flashinfer else self.forward_native
@@ -380,7 +415,7 @@ def __init__(
             prefix=f"{prefix}.out_proj",
         )
 
-        self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
+        self.chunk_gated_delta_rule = ChunkGatedDeltaRule(head_k_dim=self.head_k_dim)
         self.enable_packed_recurrent_decode = (
            envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE
        )
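For reviewers, a pure-Python restatement (not part of the patch) of the rule implemented by `_should_use_flashinfer_gdn_prefill`, with the platform probes replaced by plain arguments so the decision matrix is easy to check by hand. The capability encoding (90, 100, 103, ...) follows the `is_device_capability(90)` / `is_device_capability_family(100)` checks above, and the example rows are illustrative:

```python
# Decouples the selection rule from vLLM's `current_platform` for illustration only.
def use_flashinfer_gdn_prefill(
    backend: str, is_cuda: bool, capability: int, head_k_dim: int | None
) -> bool:
    if backend not in ("flashinfer", "auto"):
        return False  # explicit "triton" (or anything else) always falls back
    if not is_cuda:
        return False  # non-CUDA platforms use Triton/FLA
    if capability == 90:
        return True  # Hopper: no head_k_dim constraint
    if capability // 10 == 10:
        return head_k_dim == 128  # Blackwell family: SM100 GDN kernel needs 128
    return False  # anything else: Triton/FLA


examples = [
    ("auto", True, 90, 256, True),         # Hopper, any head_k_dim
    ("auto", True, 100, 128, True),        # Blackwell, head_k_dim == 128
    ("auto", True, 100, 64, False),        # Blackwell, other head_k_dim
    ("auto", True, 80, 128, False),        # Ampere: unsupported
    ("flashinfer", False, 0, 128, False),  # non-CUDA: warn and fall back
    ("triton", True, 90, 128, False),      # explicit Triton
]
for backend, is_cuda, cap, hkd, expected in examples:
    assert use_flashinfer_gdn_prefill(backend, is_cuda, cap, hkd) is expected
```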