Commit 51f297a (parent 7d91ede)

fix

Signed-off-by: ZhanqiuHu <zhu@redhat.com>

2 files changed: 97 additions & 102 deletions

vllm/distributed/kv_transfer/kv_connector/v1/nixl/transfer_plan.py

Lines changed: 24 additions & 1 deletion
@@ -4,6 +4,28 @@

 Data structures, plan generators, and local descriptor builders
 for NIXL KV cache transfers.
+
+Reference diagram::
+
+              KVCacheTensor (Shared)
+                    /    \\
+                   /      \\
+                  /        \\
+    Attention (FlashInfer) View       Mamba View
+              |                           |
+              |                           |
+    +-------------------+     +--------------------+
+    |   KVCacheTensor   |     |   KVCacheTensor    |
+    |                   |     |                    |
+    |<----- page ------>|     |<----- page ------->|
+    |       size        |     |       size         |
+    |  Key 0  |  Val 0  |     |  Conv 0 |  SSM 0   |
+    |  Key 1  |  Val 1  |     |  Conv 1 |  SSM 1   |
+    |   ...   |   ...   |     |   ...   |   ...    |
+    | Key N-2 | Val N-2 |     | Conv N-2| SSM N-2  |
+    | Key N-1 | Val N-1 |     | Conv N-1| SSM N-1  |
+    +-------------------+     +--------------------+
+    |1st_split-2nd_split|     |1st_split-2nd_split |
 """

 from __future__ import annotations
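The diagram describes how attention and Mamba layer groups alias the same backing allocation, each splitting a page into two halves. As a rough illustration (tensor names and shapes below are invented for the sketch, not vLLM's actual API), the two views can be pictured as slices over one shared buffer:

    # Illustrative sketch only: names and shapes are hypothetical.
    import torch

    num_blocks, page_size = 4, 256  # elements per page, one page per block
    kv_cache = torch.zeros(num_blocks, page_size)  # the shared KVCacheTensor

    # Attention (FlashInfer) view: each page is split Key | Val.
    key_half = kv_cache[:, : page_size // 2]   # Key 0 .. Key N-1
    val_half = kv_cache[:, page_size // 2 :]   # Val 0 .. Val N-1

    # Mamba view over the same memory: each page is split Conv | SSM.
    conv_half = kv_cache[:, : page_size // 2]  # Conv 0 .. Conv N-1
    ssm_half = kv_cache[:, page_size // 2 :]   # SSM 0 .. SSM N-1

Both views share the same page boundaries, which is why the 1st_split/2nd_split labels line up across the two boxes.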
@@ -85,7 +107,8 @@ class EngineTransferPlan:
     # Per-group KVCacheSpec type — used for descriptor indexing.
     group_spec_types: tuple[type[KVCacheSpec], ...]

-    # Per-group ordered source ranks. Position = local piece index.
+    # Remote TP ranks that this local rank reads from, per group.
+    # Position = local piece index.
     source_ranks_per_group: tuple[tuple[int, ...], ...]

     # Superset of all source ranks (union of all groups).
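To make the reworded comment concrete, here is a made-up example of how these fields could be populated (values are hypothetical, not taken from a real plan):

    # Hypothetical: attention KV is sharded across remote ranks 0 and 1,
    # while the Mamba group is served by rank 0 only.
    source_ranks_per_group = ((0, 1), (0,))  # group 0: attention, group 1: mamba
    all_source_ranks = (0, 1)                # union across groups

    # Position = local piece index: piece 0 of group 0 comes from remote
    # rank source_ranks_per_group[0][0] == 0, piece 1 from rank 1.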

vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py

Lines changed: 73 additions & 101 deletions
@@ -9,6 +9,7 @@
 import time
 import uuid
 from collections import defaultdict
+from collections.abc import Iterator
 from concurrent.futures import Future, ThreadPoolExecutor
 from typing import TYPE_CHECKING, Any, cast

@@ -89,10 +90,6 @@
 class NixlConnectorWorker:
     """Implementation of Worker side methods"""

-    # ------------------------------------------------------------------
-    # Plan executors (static — no self access)
-    # ------------------------------------------------------------------
-
     @staticmethod
     def _build_remote_descs_from_plan(
         plan: EngineTransferPlan,
@@ -130,6 +127,12 @@ def _compute_desc_ids_from_plan(

         num_fa_descs = num_fa_regions * num_blocks

+        # All-attention fast path: single vectorized broadcast.
+        if num_ssm_regions == 0:
+            block_arr = np.concatenate(block_ids)[None, :]
+            region_ids = np.arange(num_fa_regions)[:, None]
+            return (region_ids * num_blocks + block_arr).flatten()
+
         # NOTE (NickLucche) With HMA, every kv group has the same number
         # of layers and layers from different groups share the same kv
         # tensor. Therefore we compute desc IDs per group using the
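The fast path is a single numpy broadcast: every attention region registers num_blocks descriptors, so the descriptor id of block b in region r is r * num_blocks + b. A worked toy example (all values invented):

    import numpy as np

    block_ids = [[3, 7]]        # one attention group requesting blocks 3 and 7
    num_fa_regions, num_blocks = 2, 10

    block_arr = np.concatenate(block_ids)[None, :]   # shape (1, 2): [[3, 7]]
    region_ids = np.arange(num_fa_regions)[:, None]  # shape (2, 1): [[0], [1]]
    desc_ids = (region_ids * num_blocks + block_arr).flatten()
    # array([ 3,  7, 13, 17]): region 0's descs first, then region 1's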
@@ -140,13 +143,12 @@ def _compute_desc_ids_from_plan(
         all_descs: list[np.ndarray] = []
         for i, group in enumerate(block_ids):
             group_arr = np.asarray(group)
-            spec_type = plan.group_spec_types[i]
-            if _is_attention_spec(spec_type):
+            if _is_attention_spec(plan.group_spec_types[i]):
                 fa_region_ids = np.arange(num_fa_regions)[:, None]
                 all_descs.append(
                     (fa_region_ids * num_blocks + group_arr[None, :]).flatten()
                 )
-            elif _is_ssm_spec(spec_type):
+            elif _is_ssm_spec(plan.group_spec_types[i]):
                 # NOTE (NickLucche) SSM and Attention block regions can
                 # be exchanged arbitrarily by manager. Therefore, descs
                 # are laid out as:
@@ -163,47 +165,18 @@ def _compute_desc_ids_from_plan(
                     ).flatten()
                 )
             else:
-                raise ValueError(f"Unknown spec type {spec_type} at index {i}")
+                raise ValueError(
+                    f"Unknown spec type {plan.group_spec_types[i]} at index {i}"
+                )

         return np.concatenate(all_descs)

-    @staticmethod
-    def _compute_read_specs_from_plan(
-        plan: EngineTransferPlan,
-        local_block_ids: BlockIds,
-        remote_block_ids: BlockIds,
-    ) -> list[ReadSpec]:
-        """Compute read specs from plan.
-
-        For each source rank, includes only the groups whose
-        source_ranks_per_group contains that rank.
-        """
-        num_groups = len(local_block_ids)
-        return [
-            ReadSpec(
-                remote_rank=rank,
-                local_block_ids=[
-                    list(local_block_ids[g])
-                    if rank in plan.source_ranks_per_group[g]
-                    else []
-                    for g in range(num_groups)
-                ],
-                remote_block_ids=[
-                    list(remote_block_ids[g])
-                    if rank in plan.source_ranks_per_group[g]
-                    else []
-                    for g in range(num_groups)
-                ],
-            )
-            for rank in plan.all_source_ranks
-        ]
-
     @staticmethod
     def _build_local_splits_from_plan(
         plan: EngineTransferPlan,
         src_blocks_data: list[tuple[int, int, int]],
         num_fa_descs: int,
-    ) -> list[list[tuple[int, int, int]]]:
+    ) -> Iterator[list[tuple[int, int, int]]]:
         """Build split handle data for P_TP > D_TP scenario.

         num_fa_descs is the boundary between FA and SSM descriptors.
@@ -217,8 +190,6 @@ def _build_local_splits_from_plan(
         has_ssm_descs = num_fa_descs < len(src_blocks_data)
         ssm_num_splits = len(plan.source_ranks_per_group[-1]) if has_ssm_descs else 0

-        result: list[list[tuple[int, int, int]]] = []
-
         for p_idx, p_rank in enumerate(plan.all_source_ranks):
             fa_slot = plan.rank_to_attention_slot.get(p_rank, 0)

@@ -230,48 +201,7 @@ def _build_local_splits_from_plan(
             else:
                 chunk = local_len // ssm_num_splits
                 handle.append((addr + p_idx * chunk, chunk, dev))
-            result.append(handle)
-
-        return result
-
-    def _build_local_descs(
-        self,
-        base_addresses: list[int],
-        block_size_ratio: int,
-    ) -> list[tuple[int, int, int]]:
-        """Build local (src) descriptor tuples for NIXL registration."""
-        assert self.transfer_topo is not None
-        fa_regions = build_fa_local_regions(
-            self.num_blocks,
-            block_size_ratio,
-            self.block_len_per_layer,
-            self.transfer_topo.is_kv_layout_blocks_first,
-        )
-        if self._has_mamba:
-            # TODO (ZhanqiuHu): For homogeneous TP (tp_ratio == 1), the
-            # 3-read split is unnecessary — a single conv desc per block
-            # suffices. Consider adding a fast path. Currently we always
-            # register 4 regions because local descs are created before
-            # knowing the remote TP.
-            assert self._conv_decomp is not None
-            mamba_regions = build_mamba_local_regions(
-                self.block_len_per_layer,
-                self._logical_num_blocks,
-                block_size_ratio,
-                self._conv_decomp,
-                self._mamba_ssm_size,
-                self._physical_blocks_per_logical_kv_block,
-            )
-        else:
-            mamba_regions = []
-
-        result: list[tuple[int, int, int]] = []
-        for region in fa_regions + mamba_regions:
-            base = base_addresses[region.layer_idx]
-            for blk in range(region.num_blocks):
-                addr = base + blk * region.page_stride + region.offset_in_page
-                result.append((addr, region.descriptor_bytes, self.device_id))
-        return result
+            yield handle

     def __init__(
         self,
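With the return type switched to Iterator, _build_local_splits_from_plan yields one handle per source rank instead of accumulating a list, so callers that iterate once never hold every per-rank handle at the same time. The per-rank chunking it performs for SSM descriptors can be sketched in isolation (all numbers invented):

    from collections.abc import Iterator

    def split_region(
        src: tuple[int, int, int], num_splits: int
    ) -> Iterator[tuple[int, int, int]]:
        """Yield one equal-sized sub-descriptor per source rank (toy sketch)."""
        addr, local_len, dev = src
        chunk = local_len // num_splits
        for p_idx in range(num_splits):
            yield (addr + p_idx * chunk, chunk, dev)

    print(list(split_region((0x1000, 4096, 0), 2)))
    # [(4096, 2048, 0), (6144, 2048, 0)]: each rank targets its own half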
@@ -1054,7 +984,37 @@ def register_local_xfer_handler(
         block_size_ratio = self.block_size // block_size
         local_base_addresses = self.kv_caches_base_addr[self.engine_id][self.tp_rank]

-        blocks_data = self._build_local_descs(local_base_addresses, block_size_ratio)
+        fa_regions = build_fa_local_regions(
+            self.num_blocks,
+            block_size_ratio,
+            self.block_len_per_layer,
+            self.transfer_topo.is_kv_layout_blocks_first,
+        )
+        if self._has_mamba:
+            # TODO (ZhanqiuHu): For homogeneous TP (tp_ratio == 1), the
+            # 3-read split is unnecessary — a single conv desc per block
+            # suffices. Consider adding a fast path. Currently we always
+            # register 4 regions because local descs are created before
+            # knowing the remote TP.
+            assert self._conv_decomp is not None
+            mamba_regions = build_mamba_local_regions(
+                self.block_len_per_layer,
+                self._logical_num_blocks,
+                block_size_ratio,
+                self._conv_decomp,
+                self._mamba_ssm_size,
+                self._physical_blocks_per_logical_kv_block,
+            )
+        else:
+            mamba_regions = []
+
+        blocks_data: list[tuple[int, int, int]] = []
+        for region in fa_regions + mamba_regions:
+            base = local_base_addresses[region.layer_idx]
+            for blk in range(region.num_blocks):
+                addr = base + blk * region.page_stride + region.offset_in_page
+                blocks_data.append((addr, region.descriptor_bytes, self.device_id))
+
         logger.debug(
             "Created %s blocks for src engine %s and rank %s on device id %s",
             len(blocks_data),
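The inlined loop expands each region into one (address, length, device) tuple per block: the base address of the region's layer, plus a whole-page stride per block, plus the region's fixed offset within its page. A self-contained sketch with a stand-in Region type (the real region objects come from build_fa_local_regions / build_mamba_local_regions, whose fields are assumed here):

    from dataclasses import dataclass

    @dataclass
    class Region:  # stand-in for the real region type, fields assumed
        layer_idx: int
        num_blocks: int
        page_stride: int
        offset_in_page: int
        descriptor_bytes: int

    base_addresses = [0x10000]  # one layer, made-up base address
    region = Region(layer_idx=0, num_blocks=2, page_stride=512,
                    offset_in_page=256, descriptor_bytes=256)

    blocks_data: list[tuple[int, int, int]] = []
    for blk in range(region.num_blocks):
        base = base_addresses[region.layer_idx]
        addr = base + blk * region.page_stride + region.offset_in_page
        blocks_data.append((addr, region.descriptor_bytes, 0))  # device_id 0
    # [(65792, 256, 0), (66304, 256, 0)]: second half of pages 0 and 1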
@@ -1777,11 +1737,26 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
             plan.remote_expansion_stride,
         )
         remote_block_ids = meta.remote.block_ids
-        read_specs = self._compute_read_specs_from_plan(
-            plan,
-            local_block_ids=meta.local_physical_block_ids,
-            remote_block_ids=remote_block_ids,
-        )
+        local_block_ids = meta.local_physical_block_ids
+        num_groups = len(local_block_ids)
+        read_specs = [
+            ReadSpec(
+                remote_rank=rank,
+                local_block_ids=[
+                    list(local_block_ids[g])
+                    if rank in plan.source_ranks_per_group[g]
+                    else []
+                    for g in range(num_groups)
+                ],
+                remote_block_ids=[
+                    list(remote_block_ids[g])
+                    if rank in plan.source_ranks_per_group[g]
+                    else []
+                    for g in range(num_groups)
+                ],
+            )
+            for rank in plan.all_source_ranks
+        ]

         # D may have to perform multiple reads from different remote ranks.
         # MLA opt: when P TP > D TP, only a single read is executed for
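The inlined ReadSpec construction filters each group by rank membership: a rank that does not serve a group gets empty block-id lists for it, so the transfer for that group is skipped on that rank. Worked toy example (values invented):

    source_ranks_per_group = ((0, 1), (0,))  # rank 1 serves group 0 only
    all_source_ranks = (0, 1)
    local_block_ids = [[10, 11], [20]]       # per-group local blocks

    num_groups = len(local_block_ids)
    for rank in all_source_ranks:
        per_group = [
            list(local_block_ids[g]) if rank in source_ranks_per_group[g] else []
            for g in range(num_groups)
        ]
        print(rank, per_group)
    # 0 [[10, 11], [20]]
    # 1 [[10, 11], []]   <- group 1 skipped for rank 1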
@@ -1790,15 +1765,12 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
             read_specs = read_specs[:1]

         for i, spec in enumerate(read_specs):
-            remote_rank = spec.remote_rank
-            local_block_ids = spec.local_block_ids
-            remote_block_ids = spec.remote_block_ids
             remote_block_size = remote_info.remote_block_size
             logger.debug(
                 "Remote agent %s available, calling _read_blocks"
                 " on remote rank %s with remote block size %s for req %s",
                 meta.remote.engine_id,
-                remote_rank,
+                spec.remote_rank,
                 remote_block_size,
                 req_id,
             )
@@ -1817,16 +1789,14 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):

             # Destination handle: remote_engine_id -> remote_rank -> handle.
             remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote.engine_id][
-                remote_rank
+                spec.remote_rank
             ]

             self._read_blocks(
+                read_spec=spec,
                 request_id=req_id,
                 dst_engine_id=meta.remote.engine_id,
                 remote_request_id=meta.remote.request_id,
-                local_block_ids=local_block_ids,
-                remote_block_ids=remote_block_ids,
-                remote_rank=remote_rank,
                 local_xfer_side_handle=local_xfer_side_handle,
                 remote_xfer_side_handle=remote_xfer_side_handle,
             )
@@ -1843,12 +1813,10 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):

     def _read_blocks(
         self,
-        local_block_ids: BlockIds,
-        remote_block_ids: BlockIds,
+        read_spec: ReadSpec,
        dst_engine_id: str,
         request_id: str,
         remote_request_id: str,
-        remote_rank: int,
         local_xfer_side_handle: int,
         remote_xfer_side_handle: int,
     ):
@@ -1857,6 +1825,10 @@ def _read_blocks(
         a single remote worker.
         """
         assert self.transfer_topo is not None
+        remote_rank = read_spec.remote_rank
+        local_block_ids = read_spec.local_block_ids
+        remote_block_ids = read_spec.remote_block_ids
+
         plan = self._transfer_plans[dst_engine_id]
         remote_info = self.transfer_topo.get_engine_info(dst_engine_id)
         block_size_ratio = self.transfer_topo.block_size_ratio(
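For reference, _read_blocks only needs ReadSpec to carry the three fields unpacked at its top. A minimal shape consistent with this diff (the real definition lives elsewhere in the connector package; the BlockIds alias is assumed) would be:

    from dataclasses import dataclass

    BlockIds = list[list[int]]  # assumed alias: per-group block-id lists

    @dataclass
    class ReadSpec:
        remote_rank: int
        local_block_ids: BlockIds
        remote_block_ids: BlockIds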
