@@ -149,6 +149,17 @@ def __init__(
149149 # list of CUDA events available for re-use
150150 self ._event_pool : list [torch .Event ] = []
151151
152+ # Pre-compute base pointers and block sizes for batch copies.
153+ self ._src_base_ptrs = np .array (
154+ [t .data_ptr () for t in self .src_tensors ], dtype = np .int64
155+ )
156+ self ._dst_base_ptrs = np .array (
157+ [t .data_ptr () for t in self .dst_tensors ], dtype = np .int64
158+ )
159+ self ._block_size_in_bytes_arr = np .array (
160+ self .tensor_block_size_in_bytes , dtype = np .int64
161+ )
162+
152163 def transfer_async (self , job_id : int , transfer_spec : TransferSpec ) -> bool :
153164 src_spec , dst_spec = transfer_spec
154165 assert isinstance (src_spec , BlockIDsLoadStoreSpec )
@@ -165,15 +176,35 @@ def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
165176
166177 assert dst_sub_block_count == src_sub_block_count - src_sub_blocks_to_skip
167178
168- src_to_dst = np .empty ((dst_sub_block_count , 2 ), dtype = np .int64 )
179+ src_block_ids = np .empty (dst_sub_block_count , dtype = np .int64 )
180+ dst_block_ids = np .empty (dst_sub_block_count , dtype = np .int64 )
169181 expand_block_ids (
170182 src_blocks ,
171183 self .src_block_size_factor ,
172- src_to_dst [:, 0 ] ,
184+ src_block_ids ,
173185 skip_count = src_sub_blocks_to_skip ,
174186 )
175- expand_block_ids (dst_blocks , self .dst_block_size_factor , src_to_dst [:, 1 ])
176- src_to_dst_tensor = torch .from_numpy (src_to_dst )
187+ expand_block_ids (dst_blocks , self .dst_block_size_factor , dst_block_ids )
188+
189+ # Build flat pointer arrays grouped by tensor: num_pairs entries per tensor.
190+ num_pairs = dst_sub_block_count
191+ num_tensors = len (self .src_tensors )
192+ total = num_pairs * num_tensors
193+
194+ all_src = np .empty (total , dtype = np .int64 )
195+ all_dst = np .empty (total , dtype = np .int64 )
196+ all_sizes = np .empty (total , dtype = np .int64 )
197+
198+ for t_idx , bsz in enumerate (self ._block_size_in_bytes_arr ):
199+ start = t_idx * num_pairs
200+ end = start + num_pairs
201+ all_src [start :end ] = self ._src_base_ptrs [t_idx ] + src_block_ids * bsz
202+ all_dst [start :end ] = self ._dst_base_ptrs [t_idx ] + dst_block_ids * bsz
203+ all_sizes [start :end ] = bsz
204+
205+ batch_src = torch .from_numpy (all_src )
206+ batch_dst = torch .from_numpy (all_dst )
207+ batch_sizes = torch .from_numpy (all_sizes )
177208
178209 stream = self ._stream_pool .pop () if self ._stream_pool else torch .cuda .Stream ()
179210 start_event = (
@@ -197,17 +228,8 @@ def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
197228 stream .wait_event (last_event )
198229 with torch .cuda .stream (stream ):
199230 start_event .record (stream )
200- for src_tensor , dst_tensor , block_size_in_bytes in zip (
201- self .src_tensors ,
202- self .dst_tensors ,
203- self .tensor_block_size_in_bytes ,
204- ):
205- ops .swap_blocks (
206- src_tensor ,
207- dst_tensor ,
208- block_size_in_bytes ,
209- src_to_dst_tensor ,
210- )
231+ if total > 0 :
232+ ops .swap_blocks_batch (batch_src , batch_dst , batch_sizes )
211233 end_event .record (stream )
212234
213235 self ._transfer_events [job_id ] = end_event
0 commit comments