Use CU_MEMCPY_SRC_ACCESS_ORDER_ANY for batch KV cache swaps#39306

Open
Etelis wants to merge 98 commits intovllm-project:mainfrom
Etelis:try-memcpy-access-order-any

Conversation

@Etelis
Contributor

@Etelis Etelis commented Apr 8, 2026

Use CU_MEMCPY_SRC_ACCESS_ORDER_ANY instead of CU_MEMCPY_SRC_ACCESS_ORDER_STREAM in the cuMemcpyBatchAsync call used for batched KV cache swap copies.

This relaxes the source access ordering constraint, allowing the CUDA driver to pipeline reads more aggressively. The safety of this change relies on the fact that source data is always fully written before the batch copy begins — the offloading handler synchronizes via stream events (stream.wait_event(last_event)) before issuing the copy.
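The safety argument above can be sketched with a plain-Python stand-in, where a threading.Event plays the role of the CUDA stream event (the names and structure here are illustrative, not vllm's actual offloading handler):

```python
# Toy model of the ordering guarantee: the writer fully populates the source
# and records an event; the copy waits on that event before it is issued, so
# any read order -- stream-ordered or relaxed -- sees a complete source.
import threading

src = bytearray(8)               # stands in for the source buffer
last_event = threading.Event()   # stands in for the recorded CUDA event

def producer():
    # "Compute stream": finish writing the source, then record the event.
    src[:] = b"\x2a" * 8
    last_event.set()

def batched_copy(dst):
    # "Transfer stream": wait_event(last_event) before issuing the copy,
    # so no concurrent writer remains when the source is read.
    last_event.wait()
    dst[:] = src

t = threading.Thread(target=producer)
dst = bytearray(8)
t.start()
batched_copy(dst)
t.join()
assert bytes(dst) == b"\x2a" * 8
```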

Motivated by this code review comment from @ivanium on PR #38460, who observed improved CPU->GPU bandwidth on Grace Blackwell nodes with ACCESS_ORDER_ANY.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request updates the CUDA memory copy attributes in csrc/cache_kernels.cu by changing the source access order from CU_MEMCPY_SRC_ACCESS_ORDER_STREAM to CU_MEMCPY_SRC_ACCESS_ORDER_ANY. I have no further feedback to provide.

@ivanium
Contributor

ivanium commented Apr 8, 2026

Thanks for the work. Can you add some benchmark results?

@Etelis
Contributor Author

Etelis commented Apr 9, 2026

Thanks for the work. Can you add some benchmark results?

- NVIDIA GH200 480GB (Grace Hopper, NVLink-C2C, sm_90); driver 580.105.08, CUDA Toolkit 12.8
- Built vllm from source twice: once with STREAM, once with ANY in csrc/cache_kernels.cu
- Called vllm._custom_ops.swap_blocks_batch() directly with pinned host memory <-> device memory
- Timing: CUDA events (GPU-side), 200 iterations, 20 warmup, non-default stream
- Sweep: block counts 64-4096, block sizes 4KB-256KB, both CPU->GPU and GPU->CPU

GPU Copy Time (CUDA Events)

| Dir  | Blocks | BlkSize | STREAM med (ms) | ANY med (ms) | Delta |
|------|-------:|--------:|----------------:|-------------:|------:|
| c->g |     64 |     4KB |          0.0238 |       0.0229 | -3.8% |
| g->c |     64 |     4KB |          0.0229 |       0.0217 | -5.2% |
| c->g |    256 |     4KB |          0.0729 |       0.0713 | -2.2% |
| g->c |    256 |     4KB |          0.0684 |       0.0665 | -2.8% |
| c->g |   1024 |     4KB |          0.2760 |       0.2689 | -2.6% |
| g->c |   1024 |     4KB |          0.2569 |       0.2502 | -2.6% |
| c->g |   4096 |     4KB |          1.0701 |       1.0503 | -1.9% |
| g->c |   4096 |     4KB |          0.9783 |       0.9557 | -2.3% |
| c->g |     64 |    32KB |          0.0256 |       0.0246 | -3.9% |
| g->c |     64 |    32KB |          0.0237 |       0.0224 | -5.5% |
| c->g |    256 |    32KB |          0.0760 |       0.0736 | -3.2% |
| g->c |    256 |    32KB |          0.0706 |       0.0683 | -3.3% |
| c->g |     32 |   256KB |          0.0292 |       0.0293 | +0.3% |
| g->c |     32 |   256KB |          0.0310 |       0.0314 | +1.3% |
| c->g |     64 |   128KB |          0.0294 |       0.0294 |  0.0% |
| g->c |     64 |   128KB |          0.0313 |       0.0313 |  0.0% |

Host Submission Time (no sync)

| Dir  | Blocks | BlkSize | STREAM med (us) | ANY med (us) | Delta |
|------|-------:|--------:|----------------:|-------------:|------:|
| c->g |   1024 |     4KB |           270.2 |        263.1 | -2.6% |
| g->c |   1024 |     4KB |           251.3 |        244.8 | -2.6% |
| c->g |   4096 |     4KB |          1061.5 |       1043.5 | -1.7% |
| g->c |   4096 |     4KB |           971.9 |        948.6 | -2.4% |
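For reference, the medians above come from a loop of this shape (a hedged sketch; `bench` is a hypothetical helper with a host timer standing in for CUDA events, not part of vllm):

```python
# Warmup iterations are discarded, then the median of the timed samples is
# reported, matching the 200-iteration / 20-warmup methodology above.
import statistics
import time

def bench(fn, iters=200, warmup=20):
    """Run fn, discarding warmup runs; return the median sample time in seconds."""
    for _ in range(warmup):      # warmup runs are not recorded
        fn()
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - t0)
    return statistics.median(samples)

median_s = bench(lambda: sum(range(1000)), iters=50, warmup=5)
```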

I couldn't get a Grace Blackwell; the closest I could get my hands on is a Grace Hopper.
@ivanium you might be able to test it yourself as well.

@Etelis
Contributor Author

Etelis commented Apr 9, 2026

cc @orozery , @mgoin

@orozery
Collaborator

orozery commented Apr 9, 2026

@Etelis I think we want to make this a parameter.
For GPU->CPU I believe we still want stream order.

@Etelis
Contributor Author

Etelis commented Apr 9, 2026

@Etelis I think we want to make this a parameter. For GPU->CPU I believe we still want stream order.

The only place we're using swap_blocks_batch today is in cpu_gpu.py, where we already have stream.wait_event(last_event) in place (for GPU->CPU).

So that would just be an extra precaution?

@orozery
Collaborator

orozery commented Apr 9, 2026

The only place we're using swap_blocks_batch today is in cpu_gpu.py, where we already have stream.wait_event(last_event) in place (for GPU->CPU).

So that would just be an extra precaution?

I'm not sure.
Can you then give me an example where ORDER_ANY behaves differently than ORDER_STREAM?

@Etelis
Contributor Author

Etelis commented Apr 9, 2026

The only place we're using swap_blocks_batch today is in cpu_gpu.py, where we already have stream.wait_event(last_event) in place (for GPU->CPU).
So that would just be an extra precaution?

I'm not sure. Can you then give me an example where ORDER_ANY behaves differently than ORDER_STREAM?

From CUDA docs:

If the source access order is set to CU_MEMCPY_SRC_ACCESS_ORDER_STREAM, then the source will be accessed in stream order. ... If the source access order is set to CU_MEMCPY_SRC_ACCESS_ORDER_ANY then it indicates that access to the source pointer can be out of stream order and the accesses can happen even after the API call returns.

So if I understand correctly, since we're creating a new stream for the offloading operations, CU_MEMCPY_SRC_ACCESS_ORDER_STREAM will wait on any operation on the copy stream prior to the DMA batch before it starts reading source memory.

CU_MEMCPY_SRC_ACCESS_ORDER_ANY skips that — it lets the DMA engine start reading immediately without checking for prior work on the stream. But since we already do stream.wait_stream(torch.cuda.current_stream()) on the compute stream before GPU->CPU copies, there shouldn't be any case where we're still writing to those GPU blocks. The copy stream itself never writes to the source buffers — it only runs DMA copies and sync barriers.

So in this case, CU_MEMCPY_SRC_ACCESS_ORDER_STREAM is just an extra precaution.
Am I missing something?

@orozery
Collaborator

orozery commented Apr 9, 2026

What if it reads at call time, and writes at "stream time"?
I think this is the difference between STREAM and ANY:
both guarantee the write happens in stream order, but with ANY the read can happen out of stream order.

@Etelis
Contributor Author

Etelis commented Apr 9, 2026

What if it reads at call time, and writes at "stream time"? I think this is the difference between STREAM and ANY: both guarantee the write happens in stream order, but with ANY the read can happen out of stream order.

I see what you mean.
I'm speculating here, but I think that ordering overhead is where the 2-5% improvement on small blocks comes from:
the driver likely takes a lighter codepath when it knows it doesn't need to enforce source-read ordering at all.

@mergify
Contributor

mergify Bot commented Apr 11, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Etelis.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 11, 2026
Relax source access ordering in cuMemcpyBatchAsync from STREAM to ANY.
The source data is always fully written before copies start (ensured by
stream event synchronization), so strict stream ordering is redundant.
ANY gives the driver freedom to pipeline reads more aggressively, which
may improve CPU<->GPU bandwidth on NVLink-C2C interconnects (Grace
Hopper/Blackwell).

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
@Etelis Etelis force-pushed the try-memcpy-access-order-any branch from 1df9e61 to 4f51705 Compare April 12, 2026 16:48
@mergify mergify Bot removed the needs-rebase label Apr 12, 2026
@Etelis
Contributor Author

Etelis commented Apr 20, 2026

@orozery you were right, thanks for the pushback.

The subtlety is that CU_MEMCPY_SRC_ACCESS_ORDER_ANY only relaxes the source-read ordering — the destination write is still stream-ordered. So our stream.wait_stream(compute) and stream.wait_event(last_event) gate when the copy's write commits, but under ANY the DMA engine is free to prefetch the source bytes before those barriers fire. For GPU→CPU that means the DMA can start reading the KV cache before compute has finished writing it.

CPU→GPU doesn't have this problem.
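A toy, purely illustrative model of that hazard (no CUDA involved; "submission time" vs "stream time" are modeled explicitly, and all names here are hypothetical):

```python
# Under ANY, the DMA may snapshot the source as soon as the copy is
# *submitted*; under STREAM, the read happens only when the copy *executes*,
# i.e. after the wait_stream(compute) barrier has fired.
def run(access_order):
    kv = ["old"]          # live GPU KV cache (source for a GPU->CPU copy)
    timeline = []         # ops replayed later, in "stream time" order
    snapshot = None

    # Submission time: host enqueues the barrier and the copy.
    if access_order == "ANY":
        snapshot = list(kv)           # prefetch allowed at/after the API call
    timeline.append("compute_write")  # compute stream is still writing...
    timeline.append("barrier")        # ...then the transfer stream's wait fires

    # Stream time: replay the timeline, then perform the gated read.
    for op in timeline:
        if op == "compute_write":
            kv[0] = "new"
    if access_order == "STREAM":
        snapshot = list(kv)           # read stays behind the barrier
    return snapshot[0]

assert run("STREAM") == "new"   # sees the finished write
assert run("ANY") == "old"      # may observe stale data: the hazard
```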

GPU->CPU source is the live GPU KV cache, which the compute stream
keeps writing; ANY would let the DMA prefetch source bytes before
wait_stream(compute) fires. Keep STREAM there; only the CPU->GPU
handler (host pinned source) opts into ANY.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
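The resulting per-direction policy can be sketched as a small predicate (a hypothetical helper; the real change threads a boolean flag through the op instead):

```python
# Hypothetical sketch of the policy from the commit above: only CPU->GPU
# (host pinned source, fully written before submission) opts into ANY;
# GPU->CPU keeps STREAM so source reads stay behind the stream barriers.
def src_order_any_for(direction: str) -> bool:
    if direction == "h2d":   # CPU->GPU: no concurrent writer of the source
        return True
    if direction == "d2h":   # GPU->CPU: source is the live KV cache
        return False
    raise ValueError(f"unknown direction: {direction!r}")

assert src_order_any_for("h2d") is True
assert src_order_any_for("d2h") is False
```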
@Etelis Etelis requested review from ApostaC and orozery as code owners April 20, 2026 15:33
@mergify mergify Bot added the v1 label Apr 20, 2026
Comment thread csrc/cache_kernels.cu Outdated
Comment on lines +129 to +131
// source (e.g. CPU->GPU reads from host-owned pinned memory). For
// GPU->CPU we must keep STREAM so source reads are gated by the
// transfer stream's wait_stream(compute) / wait_event barriers.
Collaborator

The rest is specific to the offloading connector implementation.

Suggested change
// source (e.g. CPU->GPU reads from host-owned pinned memory). For
// GPU->CPU we must keep STREAM so source reads are gated by the
// transfer stream's wait_stream(compute) / wait_event barriers.
// source.

Comment thread vllm/_custom_ops.py Outdated
Comment on lines +2798 to +2799
writing to the source (e.g. CPU->GPU, where the source is host
pinned memory). Defaults to False (STREAM ordering), which is
Collaborator

the e.g. part is specific to the offloading connector implementation. Let's remove it.

Comment thread vllm/_custom_ops.py Outdated
"""
torch.ops._C_cache_ops.swap_blocks_batch(src_ptrs, dst_ptrs, sizes)
torch.ops._C_cache_ops.swap_blocks_batch(
src_ptrs, dst_ptrs, sizes, src_access_order_any
Collaborator

src_access_order_any is a bit confusing.
Maybe rename it to is_src_access_order_any?

Rename src_access_order_any -> is_src_access_order_any and keep the
op-level comment/docstring generic; the offloader-specific rationale
stays at the cpu_gpu.py call site where it belongs.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
@mergify
Contributor

mergify Bot commented Apr 23, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Etelis.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 23, 2026
@Etelis
Contributor Author

Etelis commented Apr 23, 2026

@orozery Addressed all three: the op-level comment and docstring are now generic, is_src_access_order_any is used across all layers, and the offloader-specific rationale stays at the cpu_gpu.py call site where it belongs.

…rder-any

# Conflicts:
#	vllm/v1/kv_offload/worker/cpu_gpu.py

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
@mergify mergify Bot removed the needs-rebase label Apr 23, 2026
Collaborator

@orozery orozery left a comment

Thanks @Etelis !

EtelisIBM added 21 commits May 1, 2026 23:06
Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
@mergify
Contributor

mergify Bot commented May 4, 2026

Hi @Etelis, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
