
[Attention] Add head_dim=512 support for FlashInfer trtllm attention backend#38822

Open
djmmoss wants to merge 3 commits into vllm-project:main from djmmoss:dmoss/trtllm-fmha-head-dim-512

Conversation

@djmmoss
Contributor

@djmmoss commented Apr 2, 2026

Add 512 to the FlashInfer backend's supported head sizes, enabling models with head_dim=512 attention layers to use the FlashInfer trtllm attention kernels on Blackwell GPUs.

Companion PR that enables the head_dim=512 cubin support in FlashInfer: flashinfer-ai/flashinfer#2959

Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request adds support for a head size of 512 to the FlashInfer attention backend. The review suggests restricting this head size to Blackwell GPUs (SM100+) by implementing a supports_combination check, since older architectures do not natively support this dimension and may crash at runtime.

Comment on lines 388 to +390
 def get_supported_head_sizes(cls) -> list[int]:
     # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
-    return [64, 128, 256]
+    return [64, 128, 256, 512]
Severity: high

Adding 512 to the supported head sizes without hardware-specific validation can lead to runtime crashes on non-Blackwell GPUs. The PR description states that head_dim=512 is intended for the TRTLLM attention kernels on Blackwell (SM100+). However, FlashInferBackend is also used on earlier architectures (SM75+) where the native FlashInfer kernels do not support this head dimension.

To prevent invalid configurations from reaching the execution stage, you should override supports_combination to ensure head_dim=512 is only permitted when the device capability is SM100 or higher.

Suggested change

    def get_supported_head_sizes(cls) -> list[int]:
        # FlashInfer native kernels support 64, 128, 256.
        # 512 is supported via TRTLLM kernels on Blackwell.
        return [64, 128, 256, 512]

    @classmethod
    def supports_combination(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: DeviceCapability,
    ) -> str | None:
        if head_size == 512 and device_capability.major < 10:
            return "head_dim=512 is only supported on Blackwell GPUs (SM100+)"
        return None

@djmmoss
Contributor Author

djmmoss commented Apr 2, 2026

The concern about runtime crashes on non-Blackwell GPUs is already handled by vLLM's existing backend selection and validation system:

  1. Backend priority routing (cuda.py): On pre-SM100 GPUs, FLASH_ATTN is prioritized over FLASHINFER, and FLASH_ATTN already supports head_dim=512. FlashInfer would only be selected on SM100+ where the trtllm cubins are available.

  2. Cubin-level validation: Even if FlashInfer were explicitly forced on an older GPU, the trtllm cubin loader validates kernel availability at initialization and raises a clear error — no silent crash.

  3. Precedent: Other backends already list 512 without architecture gating (e.g., cpu_attn returns [32, 64, 80, 96, 112, 128, 160, 192, 224, 256, 512]), and flex_attention returns [] (all sizes accepted).

get_supported_head_sizes() declares what the backend can support across all its code paths, not what a specific GPU supports. The architecture-specific filtering happens in supports_compute_capability() and the backend priority logic.
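
To make that separation concrete, here is a minimal sketch of the two-stage check described above. The names FlashInferLikeBackend and select_backend are hypothetical simplifications for illustration, not vLLM's actual selection code in cuda.py:

    # Hypothetical, simplified illustration of the points above; not vLLM's
    # actual backend-selection code.
    class FlashInferLikeBackend:
        @classmethod
        def get_supported_head_sizes(cls) -> list[int]:
            # Everything the backend can handle across all of its code paths
            # (native FlashInfer kernels plus the trtllm-gen cubins).
            return [64, 128, 256, 512]

        @classmethod
        def supports_compute_capability(cls, major: int, minor: int) -> bool:
            # Architecture gating lives here and in the priority routing,
            # not in the head-size list.
            return (major, minor) >= (7, 5)


    def select_backend(head_size: int, capability_major: int) -> str:
        # Stand-in for the priority routing described in point 1: on pre-SM100
        # GPUs FLASH_ATTN is preferred (and it already supports head_dim=512),
        # so FlashInfer only sees head_dim=512 on SM100+.
        if capability_major < 10:
            return "FLASH_ATTN"
        if head_size in FlashInferLikeBackend.get_supported_head_sizes():
            return "FLASHINFER"
        return "FLASH_ATTN"

In this sketch, select_backend(512, 9) resolves to FLASH_ATTN and select_backend(512, 10) resolves to FLASHINFER, which is the behavior points 1 and 2 rely on.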

Add 512 to the list of supported head sizes in the FlashInfer
attention backend. This enables models with head_dim=512 (used
in global/full attention layers) to use the FlashInfer trtllm-gen
attention kernels on Blackwell GPUs.

The trtllm-gen cubins for head_dim=512 are available in the
FlashInfer cubin repository for SM100f (BF16, FP8, and FP16
dtypes).

Co-authored-by: Claude
Signed-off-by: Daniel Moss <dmoss@nvidia.com>
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
@djmmoss force-pushed the dmoss/trtllm-fmha-head-dim-512 branch from f2e851e to 038ab75 on April 4, 2026 00:21
@djmmoss marked this pull request as ready for review on April 4, 2026 14:27
@mergify
Contributor

mergify Bot commented Apr 4, 2026

Documentation preview: https://vllm--38822.org.readthedocs.build/en/38822/

Signed-off-by: Duncan Moss <djm.moss@gmail.com>
@mergify
Contributor

mergify Bot commented Apr 30, 2026

Hi @djmmoss, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@vadiklyutiy
Collaborator

May we add a unit test that covers 512?

@djmmoss
Contributor Author

djmmoss commented May 6, 2026

Waiting for: #41711

| May we add a unit test that covers 512?

@vadiklyutiy this is covered in the FlashInfer repo.


Labels

documentation, nvidia, v1
