@@ -168,6 +168,128 @@ void swap_blocks(torch::Tensor &x, torch::Tensor &y, const torch::Tensor &z)
168168 return ;
169169}
170170
171+ void swap_blocks_batch (const torch::Tensor& src_ptrs,
172+ const torch::Tensor& dst_ptrs,
173+ const torch::Tensor& sizes,
174+ int64_t direction) {
175+
176+ TORCH_CHECK (src_ptrs.device ().is_cpu (), " src_ptrs must be on CPU" );
177+ TORCH_CHECK (dst_ptrs.device ().is_cpu (), " dst_ptrs must be on CPU" );
178+ TORCH_CHECK (sizes.device ().is_cpu (), " sizes must be on CPU" );
179+ TORCH_CHECK (src_ptrs.dtype () == torch::kInt64 , " src_ptrs must be int64" );
180+ TORCH_CHECK (dst_ptrs.dtype () == torch::kInt64 , " dst_ptrs must be int64" );
181+ TORCH_CHECK (sizes.dtype () == torch::kInt64 , " sizes must be int64" );
182+
183+ const int64_t n = src_ptrs.size (0 );
184+ TORCH_CHECK (dst_ptrs.size (0 ) == n, " dst_ptrs length must match src_ptrs" );
185+ TORCH_CHECK (sizes.size (0 ) == n, " sizes length must match src_ptrs" );
186+
187+ if (n == 0 ) return ;
188+
189+ const int64_t * src_data = src_ptrs.data_ptr <int64_t >();
190+ const int64_t * dst_data = dst_ptrs.data_ptr <int64_t >();
191+ const int64_t * size_data = sizes.data_ptr <int64_t >();
192+
193+ aclrtStream stream = c10_npu::getCurrentNPUStream ().stream ();
194+
195+ aclrtMemcpyKind memcpy_kind;
196+ switch (direction) {
197+ case 0 :
198+ memcpy_kind = ACL_MEMCPY_HOST_TO_DEVICE;
199+ break ;
200+ case 1 :
201+ memcpy_kind = ACL_MEMCPY_DEVICE_TO_HOST;
202+ break ;
203+ case 2 :
204+ memcpy_kind = ACL_MEMCPY_DEVICE_TO_DEVICE;
205+ break ;
206+ default :
207+ TORCH_CHECK (false ,
208+ " swap_blocks_batch: invalid direction " , direction,
209+ " (expected 0=H2D, 1=D2H, 2=D2D)" );
210+ }
211+
212+ // =========================================================================
213+ // path 1: aclrtMemcpyBatchAsync (CANN 8.5+)
214+ // =========================================================================
215+ #if defined(CANN_MEMCPY_BATCH_ASYNC)
216+ if (memcpy_kind != ACL_MEMCPY_DEVICE_TO_DEVICE) {
217+ static_assert (sizeof (void *) == sizeof (int64_t ),
218+ " void* and int64_t must be the same size" );
219+ static_assert (sizeof (size_t ) == sizeof (int64_t ),
220+ " size_t and int64_t must be the same size" );
221+
222+ void ** dst_arr = reinterpret_cast <void **>(
223+ const_cast <int64_t *>(dst_data));
224+ void ** src_arr = reinterpret_cast <void **>(
225+ const_cast <int64_t *>(src_data));
226+ size_t * size_arr = reinterpret_cast <size_t *>(
227+ const_cast <int64_t *>(size_data));
228+ size_t * dest_maxs = size_arr;
229+
230+ // aclrtMemcpyBatchAttr uses srcLoc/dstLoc (aclrtMemLocation)
231+ // to specify memory locations, not aclrtMemcpyKind.
232+ int32_t device_id = 0 ;
233+ aclrtGetDevice (&device_id);
234+
235+ aclrtMemLocation host_loc = {};
236+ host_loc.type = ACL_MEM_LOCATION_TYPE_HOST;
237+ host_loc.id = 0 ;
238+
239+ aclrtMemLocation device_loc = {};
240+ device_loc.type = ACL_MEM_LOCATION_TYPE_DEVICE;
241+ device_loc.id = device_id;
242+
243+ aclrtMemcpyBatchAttr attr = {};
244+ if (memcpy_kind == ACL_MEMCPY_HOST_TO_DEVICE) {
245+ attr.srcLoc = host_loc;
246+ attr.dstLoc = device_loc;
247+ } else { // ACL_MEMCPY_DEVICE_TO_HOST
248+ attr.srcLoc = device_loc;
249+ attr.dstLoc = host_loc;
250+ }
251+
252+ size_t attrs_index = 0 ;
253+ size_t fail_index = 0 ;
254+
255+ aclError result = aclrtMemcpyBatchAsync (
256+ dst_arr, dest_maxs, src_arr, size_arr,
257+ static_cast <size_t >(n),
258+ &attr, &attrs_index, 1 ,
259+ &fail_index, stream);
260+
261+ TORCH_CHECK (result == ACL_SUCCESS,
262+ " aclrtMemcpyBatchAsync failed at index " , fail_index,
263+ " with error code " , result);
264+ return ;
265+ }
266+ #endif
267+
268+ // =========================================================================
269+ // path 2: aclrtMemcpyAsync
270+ // =========================================================================
271+ for (int64_t i = 0 ; i < n; i++) {
272+ void * dst = reinterpret_cast <void *>(dst_data[i]);
273+ const void * src = reinterpret_cast <const void *>(src_data[i]);
274+ size_t copy_size = static_cast <size_t >(size_data[i]);
275+
276+ aclError ret = aclrtMemcpyAsync (
277+ dst,
278+ copy_size,
279+ src,
280+ copy_size,
281+ memcpy_kind,
282+ stream);
283+
284+ TORCH_CHECK (ret == ACL_SUCCESS,
285+ " aclrtMemcpyAsync failed at index " , i,
286+ " with error code " , ret,
287+ " , src=" , src_data[i],
288+ " , dst=" , dst_data[i],
289+ " , size=" , size_data[i]);
290+ }
291+ }
292+
171293AscendType get_dtype_from_torch (at::ScalarType scalarType)
172294{
173295 if (scalarType == at::ScalarType::Float) {
@@ -962,6 +1084,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
9621084 ops.def (" swap_blocks(Tensor! x, Tensor! y, Tensor z) -> ()" );
9631085 ops.impl (" swap_blocks" , torch::kPrivateUse1 , &vllm_ascend::swap_blocks);
9641086
1087+ // swap_blocks_batch takes CPU tensors (int64 pointer/size arrays), not NPU
1088+ // tensors, so dispatch must be registered on the CPU backend. The function
1089+ // internally submits async memcpy on the current NPU stream.
1090+ ops.def (" swap_blocks_batch(Tensor x, Tensor y, Tensor z, int direction) -> ()" );
1091+ ops.impl (" swap_blocks_batch" , torch::kCPU , &vllm_ascend::swap_blocks_batch);
9651092 ops.def (" device_print(str msg) -> ()" );
9661093 ops.impl (" device_print" , c10::DispatchKey::CompositeExplicitAutograd,
9671094 static_cast <void (*)(c10::string_view)>(&vllm_ascend::device_print));
0 commit comments