99import time
1010import uuid
1111from collections import defaultdict
12+ from collections .abc import Iterator
1213from concurrent .futures import Future , ThreadPoolExecutor
1314from typing import TYPE_CHECKING , Any , cast
1415
8990class NixlConnectorWorker :
9091 """Implementation of Worker side methods"""
9192
92- # ------------------------------------------------------------------
93- # Plan executors (static — no self access)
94- # ------------------------------------------------------------------
95-
9693 @staticmethod
9794 def _build_remote_descs_from_plan (
9895 plan : EngineTransferPlan ,
@@ -130,6 +127,12 @@ def _compute_desc_ids_from_plan(
130127
131128 num_fa_descs = num_fa_regions * num_blocks
132129
130+ # All-attention fast path: single vectorized broadcast.
131+ if num_ssm_regions == 0 :
132+ block_arr = np .concatenate (block_ids )[None , :]
133+ region_ids = np .arange (num_fa_regions )[:, None ]
134+ return (region_ids * num_blocks + block_arr ).flatten ()
135+
133136 # NOTE (NickLucche) With HMA, every kv group has the same number
134137 # of layers and layers from different groups share the same kv
135138 # tensor. Therefore we compute desc IDs per group using the
@@ -140,13 +143,12 @@ def _compute_desc_ids_from_plan(
140143 all_descs : list [np .ndarray ] = []
141144 for i , group in enumerate (block_ids ):
142145 group_arr = np .asarray (group )
143- spec_type = plan .group_spec_types [i ]
144- if _is_attention_spec (spec_type ):
146+ if _is_attention_spec (plan .group_spec_types [i ]):
145147 fa_region_ids = np .arange (num_fa_regions )[:, None ]
146148 all_descs .append (
147149 (fa_region_ids * num_blocks + group_arr [None , :]).flatten ()
148150 )
149- elif _is_ssm_spec (spec_type ):
151+ elif _is_ssm_spec (plan . group_spec_types [ i ] ):
150152 # NOTE (NickLucche) SSM and Attention block regions can
151153 # be exchanged arbitrarily by manager. Therefore, descs
152154 # are laid out as:
@@ -163,47 +165,18 @@ def _compute_desc_ids_from_plan(
163165 ).flatten ()
164166 )
165167 else :
166- raise ValueError (f"Unknown spec type { spec_type } at index { i } " )
168+ raise ValueError (
169+ f"Unknown spec type { plan .group_spec_types [i ]} at index { i } "
170+ )
167171
168172 return np .concatenate (all_descs )
169173
170- @staticmethod
171- def _compute_read_specs_from_plan (
172- plan : EngineTransferPlan ,
173- local_block_ids : BlockIds ,
174- remote_block_ids : BlockIds ,
175- ) -> list [ReadSpec ]:
176- """Compute read specs from plan.
177-
178- For each source rank, includes only the groups whose
179- source_ranks_per_group contains that rank.
180- """
181- num_groups = len (local_block_ids )
182- return [
183- ReadSpec (
184- remote_rank = rank ,
185- local_block_ids = [
186- list (local_block_ids [g ])
187- if rank in plan .source_ranks_per_group [g ]
188- else []
189- for g in range (num_groups )
190- ],
191- remote_block_ids = [
192- list (remote_block_ids [g ])
193- if rank in plan .source_ranks_per_group [g ]
194- else []
195- for g in range (num_groups )
196- ],
197- )
198- for rank in plan .all_source_ranks
199- ]
200-
201174 @staticmethod
202175 def _build_local_splits_from_plan (
203176 plan : EngineTransferPlan ,
204177 src_blocks_data : list [tuple [int , int , int ]],
205178 num_fa_descs : int ,
206- ) -> list [list [tuple [int , int , int ]]]:
179+ ) -> Iterator [list [tuple [int , int , int ]]]:
207180 """Build split handle data for P_TP > D_TP scenario.
208181
209182 num_fa_descs is the boundary between FA and SSM descriptors.
@@ -217,8 +190,6 @@ def _build_local_splits_from_plan(
217190 has_ssm_descs = num_fa_descs < len (src_blocks_data )
218191 ssm_num_splits = len (plan .source_ranks_per_group [- 1 ]) if has_ssm_descs else 0
219192
220- result : list [list [tuple [int , int , int ]]] = []
221-
222193 for p_idx , p_rank in enumerate (plan .all_source_ranks ):
223194 fa_slot = plan .rank_to_attention_slot .get (p_rank , 0 )
224195
@@ -230,48 +201,7 @@ def _build_local_splits_from_plan(
230201 else :
231202 chunk = local_len // ssm_num_splits
232203 handle .append ((addr + p_idx * chunk , chunk , dev ))
233- result .append (handle )
234-
235- return result
236-
237- def _build_local_descs (
238- self ,
239- base_addresses : list [int ],
240- block_size_ratio : int ,
241- ) -> list [tuple [int , int , int ]]:
242- """Build local (src) descriptor tuples for NIXL registration."""
243- assert self .transfer_topo is not None
244- fa_regions = build_fa_local_regions (
245- self .num_blocks ,
246- block_size_ratio ,
247- self .block_len_per_layer ,
248- self .transfer_topo .is_kv_layout_blocks_first ,
249- )
250- if self ._has_mamba :
251- # TODO (ZhanqiuHu): For homogeneous TP (tp_ratio == 1), the
252- # 3-read split is unnecessary — a single conv desc per block
253- # suffices. Consider adding a fast path. Currently we always
254- # register 4 regions because local descs are created before
255- # knowing the remote TP.
256- assert self ._conv_decomp is not None
257- mamba_regions = build_mamba_local_regions (
258- self .block_len_per_layer ,
259- self ._logical_num_blocks ,
260- block_size_ratio ,
261- self ._conv_decomp ,
262- self ._mamba_ssm_size ,
263- self ._physical_blocks_per_logical_kv_block ,
264- )
265- else :
266- mamba_regions = []
267-
268- result : list [tuple [int , int , int ]] = []
269- for region in fa_regions + mamba_regions :
270- base = base_addresses [region .layer_idx ]
271- for blk in range (region .num_blocks ):
272- addr = base + blk * region .page_stride + region .offset_in_page
273- result .append ((addr , region .descriptor_bytes , self .device_id ))
274- return result
204+ yield handle
275205
276206 def __init__ (
277207 self ,
@@ -1054,7 +984,37 @@ def register_local_xfer_handler(
1054984 block_size_ratio = self .block_size // block_size
1055985 local_base_addresses = self .kv_caches_base_addr [self .engine_id ][self .tp_rank ]
1056986
1057- blocks_data = self ._build_local_descs (local_base_addresses , block_size_ratio )
987+ fa_regions = build_fa_local_regions (
988+ self .num_blocks ,
989+ block_size_ratio ,
990+ self .block_len_per_layer ,
991+ self .transfer_topo .is_kv_layout_blocks_first ,
992+ )
993+ if self ._has_mamba :
994+ # TODO (ZhanqiuHu): For homogeneous TP (tp_ratio == 1), the
995+ # 3-read split is unnecessary — a single conv desc per block
996+ # suffices. Consider adding a fast path. Currently we always
997+ # register 4 regions because local descs are created before
998+ # knowing the remote TP.
999+ assert self ._conv_decomp is not None
1000+ mamba_regions = build_mamba_local_regions (
1001+ self .block_len_per_layer ,
1002+ self ._logical_num_blocks ,
1003+ block_size_ratio ,
1004+ self ._conv_decomp ,
1005+ self ._mamba_ssm_size ,
1006+ self ._physical_blocks_per_logical_kv_block ,
1007+ )
1008+ else :
1009+ mamba_regions = []
1010+
1011+ blocks_data : list [tuple [int , int , int ]] = []
1012+ for region in fa_regions + mamba_regions :
1013+ base = local_base_addresses [region .layer_idx ]
1014+ for blk in range (region .num_blocks ):
1015+ addr = base + blk * region .page_stride + region .offset_in_page
1016+ blocks_data .append ((addr , region .descriptor_bytes , self .device_id ))
1017+
10581018 logger .debug (
10591019 "Created %s blocks for src engine %s and rank %s on device id %s" ,
10601020 len (blocks_data ),
@@ -1777,11 +1737,26 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
17771737 plan .remote_expansion_stride ,
17781738 )
17791739 remote_block_ids = meta .remote .block_ids
1780- read_specs = self ._compute_read_specs_from_plan (
1781- plan ,
1782- local_block_ids = meta .local_physical_block_ids ,
1783- remote_block_ids = remote_block_ids ,
1784- )
1740+ local_block_ids = meta .local_physical_block_ids
1741+ num_groups = len (local_block_ids )
1742+ read_specs = [
1743+ ReadSpec (
1744+ remote_rank = rank ,
1745+ local_block_ids = [
1746+ list (local_block_ids [g ])
1747+ if rank in plan .source_ranks_per_group [g ]
1748+ else []
1749+ for g in range (num_groups )
1750+ ],
1751+ remote_block_ids = [
1752+ list (remote_block_ids [g ])
1753+ if rank in plan .source_ranks_per_group [g ]
1754+ else []
1755+ for g in range (num_groups )
1756+ ],
1757+ )
1758+ for rank in plan .all_source_ranks
1759+ ]
17851760
17861761 # D may have to perform multiple reads from different remote ranks.
17871762 # MLA opt: when P TP > D TP, only a single read is executed for
@@ -1790,15 +1765,12 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
17901765 read_specs = read_specs [:1 ]
17911766
17921767 for i , spec in enumerate (read_specs ):
1793- remote_rank = spec .remote_rank
1794- local_block_ids = spec .local_block_ids
1795- remote_block_ids = spec .remote_block_ids
17961768 remote_block_size = remote_info .remote_block_size
17971769 logger .debug (
17981770 "Remote agent %s available, calling _read_blocks"
17991771 " on remote rank %s with remote block size %s for req %s" ,
18001772 meta .remote .engine_id ,
1801- remote_rank ,
1773+ spec . remote_rank ,
18021774 remote_block_size ,
18031775 req_id ,
18041776 )
@@ -1817,16 +1789,14 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
18171789
18181790 # Destination handle: remote_engine_id -> remote_rank -> handle.
18191791 remote_xfer_side_handle = self .dst_xfer_side_handles [meta .remote .engine_id ][
1820- remote_rank
1792+ spec . remote_rank
18211793 ]
18221794
18231795 self ._read_blocks (
1796+ read_spec = spec ,
18241797 request_id = req_id ,
18251798 dst_engine_id = meta .remote .engine_id ,
18261799 remote_request_id = meta .remote .request_id ,
1827- local_block_ids = local_block_ids ,
1828- remote_block_ids = remote_block_ids ,
1829- remote_rank = remote_rank ,
18301800 local_xfer_side_handle = local_xfer_side_handle ,
18311801 remote_xfer_side_handle = remote_xfer_side_handle ,
18321802 )
@@ -1843,12 +1813,10 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
18431813
18441814 def _read_blocks (
18451815 self ,
1846- local_block_ids : BlockIds ,
1847- remote_block_ids : BlockIds ,
1816+ read_spec : ReadSpec ,
18481817 dst_engine_id : str ,
18491818 request_id : str ,
18501819 remote_request_id : str ,
1851- remote_rank : int ,
18521820 local_xfer_side_handle : int ,
18531821 remote_xfer_side_handle : int ,
18541822 ):
@@ -1857,6 +1825,10 @@ def _read_blocks(
18571825 a single remote worker.
18581826 """
18591827 assert self .transfer_topo is not None
1828+ remote_rank = read_spec .remote_rank
1829+ local_block_ids = read_spec .local_block_ids
1830+ remote_block_ids = read_spec .remote_block_ids
1831+
18601832 plan = self ._transfer_plans [dst_engine_id ]
18611833 remote_info = self .transfer_topo .get_engine_info (dst_engine_id )
18621834 block_size_ratio = self .transfer_topo .block_size_ratio (
0 commit comments