Skip to content

Commit ccd66dd

Browse files
committed
feat: prefetch offloader weights using batched memcpy async
1 parent a3ec4a3 commit ccd66dd

1 file changed

Lines changed: 84 additions & 17 deletions

File tree

vllm/model_executor/offloader/prefetch.py

Lines changed: 84 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import torch
1818
import torch.nn as nn
19+
import vllm._custom_ops as ops
1920

2021
# Import prefetch_ops to register custom ops at module load time
2122
import vllm.model_executor.offloader.prefetch_ops # noqa: F401
@@ -390,6 +391,10 @@ def __init__(
390391
# Used for per-layer synchronization (both eager and capture modes).
391392
self._copy_done_event = torch.cuda.Event()
392393

394+
# Fork: record event on compute stream, copy_stream waits on it.
395+
# This joins copy_stream to any active CUDA graph capture.
396+
self._fork_event = torch.cuda.Event()
397+
393398
# Track whether _copy_done_event is valid for eager-mode wait_event.
394399
# False when: (1) never recorded, or (2) last recorded during a
395400
# cudagraph capture (events become invalid after capture ends).
@@ -409,6 +414,12 @@ def __init__(
409414
self._buffer_pool: StaticBufferPool | None = None
410415
self._buffer_slot_idx: int = 0
411416

417+
# Buffer pointers
418+
# Grouped pointers enable batch copy from cuMemcpyBatchAsync.
419+
self._buffer_src_ptrs: torch.Tensor | None = None
420+
self._buffer_dst_ptrs: torch.Tensor | None = None
421+
self._buffer_sizes: torch.Tensor | None = None
422+
412423
param_dict = dict(self.module.named_parameters())
413424
assert all(name in param_dict for name in whitelist_param_names), (
414425
f"Whitelist params {whitelist_param_names} not found in module params "
@@ -485,6 +496,12 @@ def assign_buffer_slot(self, pool: StaticBufferPool, slot_idx: int):
485496
"""
486497
self._buffer_pool = pool
487498
self._buffer_slot_idx = slot_idx
499+
500+
pin_memory = should_pin_memory()
501+
502+
src_ptrs: list[int] = []
503+
dst_ptrs: list[int] = []
504+
sizes: list[int] = []
488505

489506
# Assign static buffers to parameters
490507
# Use CPU storage shape/stride/dtype since param.data is now empty
@@ -500,6 +517,37 @@ def assign_buffer_slot(self, pool: StaticBufferPool, slot_idx: int):
500517
)
501518
offloader.assign_static_buffer(buffer)
502519

520+
# IMPORTANT: record this param's src/dst pointers for the batched copy.
521+
cpu_storage = offloader._cpu_storage
522+
assert cpu_storage is not None, "CPU storage not initialized"
523+
assert not pin_memory or cpu_storage.is_pinned(), (
524+
f"CPU storage for {name} is not pinned, but pin_memory is "
525+
"enabled. The batched H2D prefetch path requires pinned "
526+
"source memory; otherwise cuMemcpyBatchAsync degrades to a "
527+
"synchronous copy and breaks event-based fork "
528+
"synchronization with the compute stream."
529+
)
530+
531+
src_ptrs.append(cpu_storage.data_ptr())
532+
dst_ptrs.append(buffer.data_ptr())
533+
sizes.append(cpu_storage.numel() * cpu_storage.element_size())
534+
535+
# Group the collected buffer pointers into tensors.
536+
if not src_ptrs:
537+
self._buffer_src_ptrs = None
538+
self._buffer_dst_ptrs = None
539+
self._buffer_sizes = None
540+
else:
541+
self._buffer_src_ptrs = torch.tensor(
542+
src_ptrs, dtype=torch.int64, pin_memory=pin_memory
543+
)
544+
self._buffer_dst_ptrs = torch.tensor(
545+
dst_ptrs, dtype=torch.int64, pin_memory=pin_memory
546+
)
547+
self._buffer_sizes = torch.tensor(
548+
sizes, dtype=torch.int64, pin_memory=pin_memory
549+
)
550+
503551
def start_onload_to_static(self):
504552
"""Start async copy from CPU storage to GPU buffer.
505553
@@ -514,33 +562,52 @@ def start_onload_to_static(self):
514562
assert self._buffer_pool is not None, "Buffer pool not assigned"
515563

516564
# Track if this prefetch is being captured (for _wait_for_layer logic)
517-
self._prefetch_in_capture = torch.cuda.is_current_stream_capturing()
565+
in_capture = torch.cuda.is_current_stream_capturing()
566+
self._prefetch_in_capture = in_capture
518567

519568
# Fork: record event on compute stream, copy_stream waits on it
520-
# This joins copy_stream to any active CUDA graph capture
521-
fork_event = torch.cuda.Event()
522-
torch.cuda.current_stream().record_event(fork_event)
523-
self.copy_stream.wait_event(fork_event)
569+
# This joins copy_stream to any active CUDA graph capture.
570+
torch.cuda.current_stream().record_event(self._fork_event)
571+
self.copy_stream.wait_event(self._fork_event)
524572

525573
with torch.cuda.stream(self.copy_stream):
526-
for name, offloader in self._param_offloaders.items():
527-
cpu_storage = offloader._cpu_storage
528-
gpu_buffer = offloader._gpu_buffer
529-
assert cpu_storage is not None, "CPU storage not initialized"
530-
assert gpu_buffer is not None, "GPU buffer not assigned"
531-
assert not should_pin_memory() or cpu_storage.is_pinned(), (
532-
f"CPU storage for {name} is not pinned! "
533-
"non_blocking=True H2D copy from non-pinned memory "
534-
"causes stream synchronization that breaks "
535-
"event-based fork synchronization."
574+
if in_capture:
575+
# cuMemcpyBatchAsync is not capture-safe.
576+
# Slow path: fall back to per-param copy_() so the copies are recorded into the graph.
577+
for name, offloader in self._param_offloaders.items():
578+
cpu_storage = offloader._cpu_storage
579+
gpu_buffer = offloader._gpu_buffer
580+
assert cpu_storage is not None, "CPU storage not initialized"
581+
assert gpu_buffer is not None, "GPU buffer not assigned"
582+
assert not should_pin_memory() or cpu_storage.is_pinned(), (
583+
f"CPU storage for {name} is not pinned! "
584+
"non_blocking=True H2D copy from non-pinned memory "
585+
"causes stream synchronization that breaks "
586+
"event-based fork synchronization."
587+
)
588+
gpu_buffer.copy_(cpu_storage, non_blocking=True)
589+
elif (
590+
self._buffer_src_ptrs is not None
591+
and self._buffer_dst_ptrs is not None
592+
and self._buffer_sizes is not None
593+
):
594+
# Fast path: batched copy using custom op (single cuMemcpyBatchAsync call on CUDA 12.8+)
595+
# cuMemcpyBatchAsync can have less driver-call overhead and better performance.
596+
# swap_blocks_batch() will fallback to per-param copy_() if cuMemcpyBatchAsync is not available.
597+
ops.swap_blocks_batch(
598+
src_ptrs=self._buffer_src_ptrs,
599+
dst_ptrs=self._buffer_dst_ptrs,
600+
sizes=self._buffer_sizes
536601
)
537-
gpu_buffer.copy_(cpu_storage, non_blocking=True)
602+
else:
603+
# No params to copy (shouldn't normally happen).
604+
pass
538605

539606
# Record completion event for _wait_for_layer to use
540607
self._copy_done_event.record(self.copy_stream)
541608
# Event is only valid for eager wait_event if recorded outside capture.
542609
# Events recorded during capture become invalid after capture ends.
543-
self._event_valid_for_eager = not torch.cuda.is_current_stream_capturing()
610+
self._event_valid_for_eager = not in_capture
544611

545612

546613
class _BaseParamOffloader(ABC):

0 commit comments

Comments
 (0)