Use CU_MEMCPY_SRC_ACCESS_ORDER_ANY for batch KV cache swaps #39306
base: main
@@ -2782,6 +2782,7 @@ def swap_blocks_batch(
     src_ptrs: torch.Tensor,
     dst_ptrs: torch.Tensor,
     sizes: torch.Tensor,
+    src_access_order_any: bool = False,
 ) -> None:
     """
     Batch version of swap_blocks: submit all copies in a single driver call.
@@ -2790,8 +2791,17 @@ def swap_blocks_batch(
     of sizes[i] bytes. All three tensors must be int64 CPU tensors.
     On CUDA 12.8+ this uses cuMemcpyBatchAsync for minimal submission
     overhead; on older CUDA it falls back to a loop of cudaMemcpyAsync.
+
+    src_access_order_any: if True, pass CU_MEMCPY_SRC_ACCESS_ORDER_ANY to
+        cuMemcpyBatchAsync, letting the DMA engine prefetch source bytes
+        out of stream order. Only safe when no GPU stream is concurrently
+        writing to the source (e.g. CPU->GPU, where the source is host
+        pinned memory). Defaults to False (STREAM ordering), which is
+        always safe.
     """
-    torch.ops._C_cache_ops.swap_blocks_batch(src_ptrs, dst_ptrs, sizes)
+    torch.ops._C_cache_ops.swap_blocks_batch(
+        src_ptrs, dst_ptrs, sizes, src_access_order_any
+    )


 def convert_fp8(
Review comment: The rest is specific to the offloading connector implementation.
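For orientation, here is a minimal sketch of how the swap-in (CPU->GPU) path could call the extended wrapper. Only the `swap_blocks_batch(src_ptrs, dst_ptrs, sizes, src_access_order_any)` signature comes from the diff above; the `vllm._custom_ops` import path, the `swap_in_blocks` helper, and the way the pointer/size tensors are built are illustrative assumptions, not code from this PR.

```python
# Hypothetical caller sketch (not part of this PR).
import torch

from vllm import _custom_ops as ops  # assumed import path for the wrapper above


def swap_in_blocks(host_blocks: list[torch.Tensor],
                   device_blocks: list[torch.Tensor]) -> None:
    """Copy pinned-host KV blocks to their GPU destinations in one driver call."""
    # Per the docstring above, all three tensors must be int64 CPU tensors.
    src_ptrs = torch.tensor([b.data_ptr() for b in host_blocks], dtype=torch.int64)
    dst_ptrs = torch.tensor([b.data_ptr() for b in device_blocks], dtype=torch.int64)
    sizes = torch.tensor(
        [b.numel() * b.element_size() for b in host_blocks], dtype=torch.int64
    )
    # The sources are pinned host buffers that no GPU stream writes to, so ANY
    # source-access ordering is safe and lets the DMA engine prefetch them.
    ops.swap_blocks_batch(src_ptrs, dst_ptrs, sizes, src_access_order_any=True)
```

For the GPU->CPU swap-out direction the source blocks may still be written by kernels in flight on another stream, so callers there should keep the default `src_access_order_any=False` (stream-ordered) behavior.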