17 changes: 17 additions & 0 deletions docker/Dockerfile
@@ -176,6 +176,15 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'); \
fi

# On CUDA 13 builds, install `nvidia-cutlass-dsl[cu13]` — it ships the
# CuTe-DSL libs that FlashInfer's Blackwell SM100 GDN prefill kernel
# (flashinfer-ai/flashinfer#3001) JIT-requires.
RUN --mount=type=cache,target=/root/.cache/uv \
if [ "${CUDA_VERSION%%.*}" = "13" ]; then \
uv pip install --python /opt/venv/bin/python3 \
"nvidia-cutlass-dsl[cu13]>=4.4.2"; \
fi

# Track PyTorch lib versions used during build and match in downstream instances.
# We do this for both nightly and release so we can strip dependencies/*.txt as needed.
# Otherwise library dependencies can upgrade/downgrade torch incorrectly.
@@ -577,6 +586,14 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') && \
rm /tmp/requirements-cuda.txt /tmp/common.txt

# On CUDA 13 builds, install `nvidia-cutlass-dsl[cu13]` — it ships the
# CuTe-DSL libs that FlashInfer's Blackwell SM100 GDN prefill kernel
# (flashinfer-ai/flashinfer#3001) JIT-requires.
RUN --mount=type=cache,target=/root/.cache/uv \
if [ "${CUDA_VERSION%%.*}" = "13" ]; then \
uv pip install --system "nvidia-cutlass-dsl[cu13]>=4.4.2"; \
fi

# Install FlashInfer JIT cache (requires CUDA-version-specific index URL)
# https://docs.flashinfer.ai/installation.html
# From versions.json: .flashinfer.version
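As a post-build sanity check, one could confirm inside the CUDA 13 image that the CuTe-DSL distributions are visible to Python. A minimal sketch, assuming the distribution names used in the comments above and in the probe added to vllm/platforms/interface.py below:

from importlib.metadata import PackageNotFoundError, version

for name in ("nvidia-cutlass-dsl", "nvidia-cutlass-dsl-libs-cu13"):
    try:
        # Print the installed version, e.g. 4.4.2 or newer on CUDA 13 builds.
        print(name, version(name))
    except PackageNotFoundError:
        print(name, "not installed")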
10 changes: 10 additions & 0 deletions setup.py
@@ -967,6 +967,16 @@ def _read_requirements(filename: str) -> list[str]:
# vllm-flash-attn is built only for CUDA 12.x.
# Skip for other versions.
continue
if req.startswith("nvidia-cutlass-dsl") and cuda_major == "13":
# On CUDA 13 builds, pull in the `[cu13]` extra — it ships
# the CuTe-DSL libs that FlashInfer's Blackwell SM100 GDN
# prefill kernel (flashinfer-ai/flashinfer#3001) JIT-
# requires.
req = req.replace(
"nvidia-cutlass-dsl",
"nvidia-cutlass-dsl[cu13]",
1,
)
modified_requirements.append(req)
requirements = modified_requirements
elif _is_hip():
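To illustrate the rewrite above: only the distribution name is replaced, so whatever version specifier the requirements file carries is preserved. A minimal sketch (the pin below is an assumption for the example, not the actual entry in requirements/*.txt):

req = "nvidia-cutlass-dsl>=4.4.2"
req = req.replace("nvidia-cutlass-dsl", "nvidia-cutlass-dsl[cu13]", 1)
assert req == "nvidia-cutlass-dsl[cu13]>=4.4.2"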
97 changes: 70 additions & 27 deletions vllm/model_executor/layers/mamba/gdn_linear_attn.py
@@ -67,6 +67,67 @@
logger = init_logger(__name__)


def _should_use_flashinfer_gdn_prefill(backend: str, head_k_dim: int | None) -> bool:
"""Whether to use FlashInfer's GDN prefill kernel instead of the
Triton/FLA fallback.

Requirements:
* ``backend in ["flashinfer", "auto"]``;
* ``platform == cuda``;
* one of the following:
- Hopper (SM90) — no further constraints;
- Blackwell (SM100 family) with ``head_k_dim == 128``,
``nvidia-cutlass-dsl-libs-cu13`` installed, ``cuda_runtime >= 13``.
"""
if backend not in ["flashinfer", "auto"]:
return False
if not current_platform.is_cuda():
return False
if current_platform.is_device_capability(90):
return True # Hopper — no further constraints.
if not current_platform.is_device_capability_family(100):
return False # Neither Hopper nor Blackwell.
if head_k_dim != 128:
return False
if current_platform.get_cuda_runtime_major() < 13:
return False
return current_platform.has_cutlass_dsl_cu13()
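# Illustrative decision table for the helper above (hypothetical inputs, not
# exhaustive; assumes a CUDA host unless noted):
#   backend="triton"                                        -> False (backend excluded)
#   backend="auto",       SM90 (Hopper), any head_k_dim     -> True
#   backend="flashinfer", SM8x (Ampere/Ada)                 -> False (unsupported arch)
#   backend="flashinfer", SM100, head_k_dim=64              -> False (head_k_dim != 128)
#   backend="flashinfer", SM100, head_k_dim=128, CUDA 12.x  -> False (runtime < 13)
#   backend="flashinfer", SM100, head_k_dim=128, CUDA 13.x,
#       nvidia-cutlass-dsl-libs-cu13 installed              -> True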


def _log_gdn_backend_decision(
backend: str, head_k_dim: int | None, use_flashinfer: bool
) -> None:
"""Dump the inputs to the backend decision and the final choice."""
is_cuda = current_platform.is_cuda()
platform = "cuda" if is_cuda else current_platform.device_name
cuda_runtime = torch.version.cuda or "n/a"
device_cap = str(current_platform.get_device_capability()) if is_cuda else "n/a"
cutlass_dsl_cu13_installed = current_platform.has_cutlass_dsl_cu13()
logger.info_once(
"GDN prefill backend inputs:\n"
" requested=%s\n"
" platform=%s, cuda_runtime=%s, device_capability=%s\n"
" nvidia_cutlass_dsl_libs_cu13_installed=%s\n"
" head_k_dim=%s",
backend,
platform,
cuda_runtime,
device_cap,
cutlass_dsl_cu13_installed,
head_k_dim,
scope="local",
)
if use_flashinfer:
logger.info_once("Using FlashInfer GDN prefill kernel")
logger.info_once(
"FlashInfer GDN prefill kernel is JIT-compiled; first run may "
"take a while to compile. Set `--gdn-prefill-backend triton` to "
"avoid JIT compile time.",
)
else:
logger.info_once("Using Triton/FLA GDN prefill kernel")


def fi_chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
@@ -118,39 +179,21 @@ def fi_chunk_gated_delta_rule(

@CustomOp.register("chunk_gated_delta_rule")
class ChunkGatedDeltaRule(CustomOp):
def __init__(self) -> None:
def __init__(self, head_k_dim: int | None = None) -> None:
super().__init__()
additional_config = get_current_vllm_config().additional_config
assert isinstance(additional_config, dict)
backend_cfg = additional_config.get("gdn_prefill_backend", "auto")
backend = str(backend_cfg).strip().lower()

supports_flashinfer = (
current_platform.is_cuda() and current_platform.is_device_capability(90)
)

if backend == "flashinfer":
use_flashinfer = supports_flashinfer
if not use_flashinfer:
logger.warning_once(
"GDN prefill backend 'flashinfer' is selected but "
"cannot use this kernel on the current platform. "
"Falling back to Triton/FLA."
)
elif backend == "triton":
use_flashinfer = False
else:
use_flashinfer = supports_flashinfer

if use_flashinfer:
logger.info_once("Using FlashInfer GDN prefill kernel")
logger.info_once(
"FlashInfer GDN prefill kernel is JIT-compiled; first run may "
"take a while to compile. Set `--gdn-prefill-backend triton` to "
"avoid JIT compile time.",
use_flashinfer = _should_use_flashinfer_gdn_prefill(backend, head_k_dim)
if backend == "flashinfer" and not use_flashinfer:
logger.warning_once(
"GDN prefill backend 'flashinfer' is selected but "
"cannot use this kernel on the current platform. "
"Falling back to Triton/FLA."
)
else:
logger.info_once("Using Triton/FLA GDN prefill kernel")
_log_gdn_backend_decision(backend, head_k_dim, use_flashinfer)

self._forward_method = (
self.forward_cuda if use_flashinfer else self.forward_native
@@ -380,7 +423,7 @@ def __init__(
prefix=f"{prefix}.out_proj",
)

self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
self.chunk_gated_delta_rule = ChunkGatedDeltaRule(head_k_dim=self.head_k_dim)
self.enable_packed_recurrent_decode = (
envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE
)
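For context, the ChunkGatedDeltaRule constructor above resolves the backend from additional_config["gdn_prefill_backend"] ("auto" by default). A hedged sketch of forcing the Triton/FLA fallback from the offline API, assuming the vllm.LLM entry point forwards additional_config to the engine config (the model name is a placeholder):

from vllm import LLM

# Hypothetical usage: skip the FlashInfer JIT-compiled prefill kernel entirely.
llm = LLM(
    model="<gdn-based-model>",
    additional_config={"gdn_prefill_backend": "triton"},
)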
19 changes: 19 additions & 0 deletions vllm/platforms/interface.py
@@ -359,6 +359,25 @@ def is_device_capability_family(
return False
return (current_capability.to_int() // 10) == (capability // 10)

@classmethod
def get_cuda_runtime_major(cls) -> int:
"""Major ``torch.version.cuda`` version, or ``0`` if undetermined."""
major = (torch.version.cuda or "0").split(".", 1)[0]
return int(major) if major.isdigit() else 0

@classmethod
def has_cutlass_dsl_cu13(cls) -> bool:
"""Whether ``nvidia-cutlass-dsl-libs-cu13`` is installed."""
try:
from importlib.metadata import distribution
except ImportError:
return False
try:
distribution("nvidia-cutlass-dsl-libs-cu13")
except Exception:
return False
return True
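# Illustration of the two helpers above (hypothetical values):
#   torch.version.cuda == "13.0" -> get_cuda_runtime_major() == 13
#   torch.version.cuda is None   -> get_cuda_runtime_major() == 0  (CPU-only torch build)
#   has_cutlass_dsl_cu13() is True only when the "nvidia-cutlass-dsl-libs-cu13"
#   distribution is discoverable via importlib.metadata.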

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
"""Get the name of a device."""