
[Bugfix] Runtime driver check for cuMemcpyBatchAsync in swap_blocks_batch#38919

Merged
mgoin merged 12 commits into vllm-project:main from Etelis:fix/swap-blocks-batch-runtime-driver-check
Apr 11, 2026

Conversation

@Etelis
Contributor

@Etelis Etelis commented Apr 3, 2026

Fixes two issues introduced by swap_blocks_batch (#38460):

  1. undefined symbol: cuMemcpyBatchAsync on CUDA drivers < 12.8 (@JaheimLee) — pre-built wheels hard-link the symbol, so import vllm._C crashes at import time on older drivers.
  2. Compile error on CUDA 13.0 (@bbrowning, @eugr) — the CUDA 13.0 headers #define cuMemcpyBatchAsync cuMemcpyBatchAsync_v2 (8 params), breaking the original 9-param call.

#38915 fixed problem 2 with compile-time #ifdef branching but left problem 1 open. This PR supersedes that approach by resolving the function at runtime via cuGetProcAddress("cuMemcpyBatchAsync", ..., 12080):

  • No direct symbol in the binary -> no crash on old drivers
  • String-based lookup -> immune to CUDA 13.0 #define remapping
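The lazy-resolution-with-fallback pattern can be sketched roughly as follows. This is a minimal plain-C++ mock, not the actual csrc/cache_kernels.cu code: mock_cuGetProcAddress, fake_batch_copy, and the function signatures are simplified stand-ins for the CUDA driver API, so the shape of the logic (resolve once by string name + version, cache the result, fall back when null) is what matters here.

```cpp
#include <cstddef>
#include <cstring>

// Simplified stand-in for the batched-copy entry point's signature.
using BatchFn = int (*)(void* dsts, void* srcs, std::size_t count);

static int fake_batch_copy(void*, void*, std::size_t) { return 0; }

// Mock driver version; the real code queries the loaded CUDA driver.
static int g_driver_version = 12080;

// Stand-in for cuGetProcAddress: resolve an entry point by string name
// and minimum version. Because resolution happens at runtime, no
// cuMemcpyBatchAsync symbol is hard-linked into the binary.
int mock_cuGetProcAddress(const char* name, void** fn, int version_required) {
    if (std::strcmp(name, "cuMemcpyBatchAsync") == 0 &&
        g_driver_version >= version_required) {
        *fn = reinterpret_cast<void*>(&fake_batch_copy);
        return 0;    // CUDA_SUCCESS
    }
    *fn = nullptr;
    return 500;      // CUDA_ERROR_NOT_FOUND
}

// Lazy, cached resolution: look the symbol up once and remember the
// result (including "not available", stored as nullptr).
BatchFn get_batch_fn() {
    static BatchFn cached = [] {
        void* fn = nullptr;
        mock_cuGetProcAddress("cuMemcpyBatchAsync", &fn, 12080);
        return reinterpret_cast<BatchFn>(fn);
    }();
    return cached;
}

// Returns true when the batched path was taken, false when the caller
// must fall back to individual copies (cudaMemcpyAsync in the real code).
bool swap_blocks_batch(void* dsts, void* srcs, std::size_t count) {
    if (BatchFn fn = get_batch_fn()) {
        fn(dsts, srcs, count);   // single batched call
        return true;
    }
    return false;                // fallback path
}
```

The caching matters because the resolution cost is paid once per process rather than per swap, and a failed lookup is cached just as cheaply as a successful one.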

…ks_batch

Replace the compile-time-only #ifdef guard for cuMemcpyBatchAsync with
a runtime resolution via cuGetProcAddress. Pre-built wheels compiled
with CUDA 12.8+ would fail with "undefined symbol: cuMemcpyBatchAsync"
on systems with older CUDA drivers (e.g. driver 12.1). The function
pointer is now resolved lazily and cached, falling back to individual
cudaMemcpyAsync calls when the driver lacks support.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
@mergify mergify Bot added the bug Something isn't working label Apr 3, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request updates the swap_blocks_batch function in csrc/cache_kernels.cu to resolve cuMemcpyBatchAsync at runtime using cuGetProcAddress. This change ensures that binaries compiled with CUDA 12.8+ remain compatible with older drivers by falling back to individual async copies if the batch function is unavailable. I have no feedback to provide as the implementation correctly handles the dynamic loading and fallback logic.

@johnnynunez
Contributor

cc @mgoin

@eugr

eugr commented Apr 3, 2026

Thanks, building with this PR now

@bbrowning
Collaborator

After applying this patch on top of the latest main I was able to build vLLM from source again with CUDA 13 on my DGX Spark, so I'm hopeful that eugr will report success as well.

@eugr

eugr commented Apr 3, 2026

The rebuild was successful and the regression test pipeline is halfway through now; so far so good.

@Etelis
Contributor Author

Etelis commented Apr 3, 2026

I will also rerun it, but please do update with your results.

@eugr

eugr commented Apr 3, 2026

All checks passed, everything is good! Thanks for a quick turnaround!

@johnnynunez
Contributor

Resolve the conflicts

@Etelis
Contributor Author

Etelis commented Apr 4, 2026

@orozery cc

…ch-runtime-driver-check

# Conflicts:
#	csrc/cache_kernels.cu

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Member

@mgoin mgoin left a comment


Seems reasonable to have the fallback, and thanks folks for confirming the fix

@mgoin mgoin added ready ONLY add when PR is ready to merge/full CI is needed nvidia labels Apr 4, 2026
@github-project-automation github-project-automation Bot moved this to Ready in NVIDIA Apr 4, 2026
@orozery
Collaborator

orozery commented Apr 5, 2026

@Etelis How does this work for CUDA 13 if it expects 8 arguments instead of 9?

@Etelis
Contributor Author

Etelis commented Apr 5, 2026

@Etelis How does this work for CUDA 13 if it expects 8 arguments instead of 9?

cuGetProcAddress with version 12080 always returns the 9-param function pointer, regardless of the driver version. CUDA drivers maintain all old function versions — the 13.0 change was only a header-level #define remapping to cuMemcpyBatchAsync_v2. Since we resolve by string name + explicit version at runtime, the header macro doesn't apply. The BatchFn typedef matches the v1 signature exactly.

Tested with both.
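The header-vs-string distinction above can be illustrated with a toy plain-C++ model. Nothing here is real CUDA: the two string variables and the lookup table are hypothetical stand-ins for the driver's exported v1/v2 entry points. The point is only that a #define rewrites identifier references at compile time but never touches string literals, which is why a name-based runtime lookup is immune to the remapping.

```cpp
#include <map>
#include <string>

// Toy stand-ins for the two exported versions of the same entry point.
static std::string cuMemcpyBatchAsync_v1_impl = "v1 (9 params)";
static std::string cuMemcpyBatchAsync_v2      = "v2 (8 params)";

// What the CUDA 13.0 headers effectively do:
#define cuMemcpyBatchAsync cuMemcpyBatchAsync_v2

// Code that names the function directly is silently remapped to v2 by
// the preprocessor...
std::string via_identifier() { return cuMemcpyBatchAsync; }

// ...but macro expansion never happens inside string literals, so a
// name-based lookup (like cuGetProcAddress("cuMemcpyBatchAsync", ...,
// 12080) in the PR) still reaches the v1 entry point the binary's
// BatchFn typedef was written against.
static const std::map<std::string, std::string> driver_table = {
    {"cuMemcpyBatchAsync", cuMemcpyBatchAsync_v1_impl},
};

std::string via_string_lookup() {
    return driver_table.at("cuMemcpyBatchAsync");
}
```

Since drivers keep exporting every historical version of an entry point, resolving by string name plus an explicit version pins the v1 ABI regardless of which toolkit headers the code was compiled against.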

@Etelis
Contributor Author

Etelis commented Apr 6, 2026

@mgoin Can we merge this?

@markmc
Member

markmc commented Apr 8, 2026

Needed this on RHEL 9 with nvidia-driver-550.163.01 and CUDA 13, seems to work fine

@JaheimLee

Hi, any update? @mgoin

@mgoin
Member

mgoin commented Apr 11, 2026

Thanks for the ping!

@mgoin mgoin merged commit bd8bd52 into vllm-project:main Apr 11, 2026
142 checks passed
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA Apr 11, 2026
wojciech-wais pushed a commit to wojciech-wais/vllm that referenced this pull request Apr 13, 2026
…atch (vllm-project#38919)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
whk-lab pushed a commit to whk-lab/vllm that referenced this pull request Apr 23, 2026
…atch (vllm-project#38919)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
avinashsingh77 pushed a commit to avinashsingh77/vllm that referenced this pull request Apr 27, 2026
…atch (vllm-project#38919)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Avinash Singh <avinashsingh.rcoem@gmail.com>

Labels

bug Something isn't working nvidia ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done

9 participants