@@ -151,6 +151,136 @@ void xpuAsyncMemcpy(
151151 }
152152}
153153
154+ // Infer which XPU device a USM device pointer was allocated on by probing
155+ // each device's SYCL context. Returns the device index on success.
156+ // This is O(num_xpu_devices) but avoids threading an explicit device argument
157+ // through the entire call chain when all callers already have the pointer.
158+ static at::DeviceIndex infer_xpu_device_from_ptr (const void * device_ptr) {
159+ const int n_devs = c10::xpu::device_count ();
160+ for (int i = 0 ; i < n_devs; i++) {
161+ auto ctx = vllm::xpu::vllmGetQueue (i).get_context ();
162+ auto type = sycl::get_pointer_type (device_ptr, ctx);
163+ if (type == sycl::usm::alloc::device || type == sycl::usm::alloc::shared) {
164+ return static_cast <at::DeviceIndex>(i);
165+ }
166+ }
167+ TORCH_CHECK (false , " Cannot determine XPU device from pointer" );
168+ return -1 ;
169+ }
170+
171+ void xpuAsyncMemcpyBatch (
172+ const uint64_t * src_ptrs,
173+ const uint64_t * dst_ptrs,
174+ const uint64_t * sizes,
175+ int64_t n) {
176+ if (n == 0 ) return ;
177+
178+ // Scan the first non-zero entry to determine copy direction.
179+ // Also capture the device-side pointer so we can infer which XPU to use.
180+ const void * device_probe = nullptr ;
181+ bool needs_staging = false ;
182+ bool dst_is_pageable = false ; // D2H to pageable host -> sync copy
183+ for (int64_t i = 0 ; i < n; i++) {
184+ if (sizes[i] == 0 ) continue ;
185+ const void * first_src = reinterpret_cast <const void *>(src_ptrs[i]);
186+ const void * first_dst = reinterpret_cast <const void *>(dst_ptrs[i]);
187+
188+ // Use device 0's context as a probe: we only need pointer *type* here,
189+ // and USM pointer types are consistent across all devices on the same
190+ // platform (host/unknown are always host; device is always device on its
191+ // own platform). The actual device index is resolved below via
192+ // infer_xpu_device_from_ptr().
193+ auto probe_ctx = vllm::xpu::vllmGetQueue (0 ).get_context ();
194+ auto src_type = sycl::get_pointer_type (first_src, probe_ctx);
195+ auto dst_type = sycl::get_pointer_type (first_dst, probe_ctx);
196+ bool src_is_host =
197+ (src_type == sycl::usm::alloc::host ||
198+ src_type == sycl::usm::alloc::unknown);
199+ bool dst_is_device = (dst_type == sycl::usm::alloc::device);
200+ needs_staging = src_is_host && dst_is_device;
201+ // D2H to pageable host requires synchronous copy to avoid corruption.
202+ dst_is_pageable = !dst_is_device && (dst_type == sycl::usm::alloc::unknown);
203+ // Device-side pointer: dst for H2D, src for D2H or D2D.
204+ device_probe = needs_staging ? first_dst : first_src;
205+ break ;
206+ }
207+
208+ if (device_probe == nullptr ) return ; // all sizes are zero
209+
210+ // Infer the target XPU device from the device pointer and set the guard so
211+ // that vllmGetQueue() returns the correct in-order queue.
212+ const at::DeviceIndex dev = infer_xpu_device_from_ptr (device_probe);
213+ const at::DeviceGuard device_guard (at::Device (at::kXPU , dev));
214+
215+ auto & queue = vllm::xpu::vllmGetQueue ();
216+
217+ // Compute total bytes needed for the H2D staging buffer.
218+ uint64_t total_bytes = 0 ;
219+ for (int64_t i = 0 ; i < n; i++) {
220+ total_bytes += sizes[i];
221+ }
222+
223+ if (needs_staging) {
224+ // H2D: allocate one contiguous pinned staging buffer, snapshot all source
225+ // blocks, then submit all async DMAs. This avoids N separate allocator
226+ // round-trips and protects against caller mutation after return.
227+ auto staging = at::getHostAllocator (at::kXPU )->allocate (
228+ static_cast <size_t >(total_bytes));
229+ char * staging_ptr = static_cast <char *>(staging.get ());
230+ TORCH_CHECK (staging_ptr, " Failed to allocate pinned staging buffer" );
231+
232+ // Phase 1: snapshot all source blocks into staging (pure CPU work).
233+ size_t staging_offset = 0 ;
234+ for (int64_t i = 0 ; i < n; i++) {
235+ size_t sz = static_cast <size_t >(sizes[i]);
236+ if (sz == 0 ) continue ;
237+ std::memcpy (
238+ staging_ptr + staging_offset,
239+ reinterpret_cast <const void *>(src_ptrs[i]),
240+ sz);
241+ staging_offset += sz;
242+ }
243+
244+ // Phase 2: submit async DMA from staging to device in a tight loop,
245+ // maximising PCIe/copy-engine throughput without interleaved CPU work.
246+ staging_offset = 0 ;
247+ for (int64_t i = 0 ; i < n; i++) {
248+ size_t sz = static_cast <size_t >(sizes[i]);
249+ if (sz == 0 ) continue ;
250+ queue.memcpy (
251+ reinterpret_cast <void *>(dst_ptrs[i]),
252+ staging_ptr + staging_offset,
253+ sz);
254+ staging_offset += sz;
255+ }
256+
257+ // Keep the staging buffer alive until all submitted DMAs complete.
258+ if (staging.get_context () != nullptr ) {
259+ at::getHostAllocator (at::kXPU )->record_event (
260+ staging_ptr,
261+ const_cast <void *>(staging.get_context ()),
262+ at::xpu::getCurrentXPUStream ());
263+ }
264+ } else {
265+ // D2H or D2D: dst_is_pageable was probed once from the first non-zero
266+ // entry (all entries share the same direction and memory class).
267+ // Pageable D2H is unsafe with async DMA; fall back to sync copy.
268+ for (int64_t i = 0 ; i < n; i++) {
269+ size_t sz = static_cast <size_t >(sizes[i]);
270+ if (sz == 0 ) continue ;
271+
272+ const void * src = reinterpret_cast <const void *>(src_ptrs[i]);
273+ void * dst = reinterpret_cast <void *>(dst_ptrs[i]);
274+
275+ if (dst_is_pageable) {
276+ queue.memcpy (dst, src, sz).wait ();
277+ } else {
278+ queue.memcpy (dst, src, sz);
279+ }
280+ }
281+ }
282+ }
283+
154284} // namespace xpu
155285} // namespace vllm
156286
0 commit comments