nixl refactor: new transfer design #40731
ZhanqiuHu wants to merge 44 commits into vllm-project:main
Conversation
Code Review
This pull request refactors the NIXL connector to use a plan-based transfer design, which unifies Dense and Mamba implementations into a model-agnostic execution path. By pre-generating an EngineTransferPlan during the handshake, the hot path is simplified and model-specific branching is removed. The review feedback highlights several opportunities to improve robustness, specifically by replacing assertions with explicit error handling for block count mismatches and adding divisibility checks when calculating logical blocks and chunk sizes to prevent potential data corruption or crashes in heterogeneous tensor parallel configurations.
```diff
-            num_local_blocks = len(local_block_ids[i])
-            if not self._is_mamba_group[i]:
-                assert num_local_blocks <= num_remote_blocks
-            # Partial prefix cache hit: just read uncomputed blocks.
-            # Skip mamba groups — their blocks represent full state (conv+ssm),
-            # not per-token data, so trimming would corrupt the transfer.
-            if num_local_blocks < num_remote_blocks and not self._is_mamba_group[i]:
+            assert num_local_blocks <= len(remote_group)
+            if num_local_blocks < len(remote_group):
```
The assertion `assert num_local_blocks <= len(remote_group)` can cause a hard crash of the worker if the producer sends fewer blocks than the consumer expects (e.g., due to a race condition or misconfiguration). It is safer to handle this as a transfer failure for the specific request, allowing the engine to continue processing other requests. Since `_read_blocks` is called within a try-except block in `_read_blocks_for_req`, raising a `ValueError` will be caught and handled gracefully.
Suggested change:
```diff
-            assert num_local_blocks <= len(remote_group)
-            if num_local_blocks < len(remote_group):
+            num_local_blocks = len(local_block_ids[i])
+            if num_local_blocks > len(remote_group):
+                raise ValueError(
+                    f"Group {i}: local block count ({num_local_blocks}) "
+                    f"exceeds remote block count ({len(remote_group)})")
+            if num_local_blocks < len(remote_group):
```
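For context on the error-handling path the review refers to, here is a minimal runnable sketch. The names `read_blocks` and `mark_failed` are hypothetical stand-ins for the worker's real methods; only the control flow mirrors the try-except guard described above.

```python
import logging

logger = logging.getLogger(__name__)

# Hypothetical stand-ins for the worker's methods; this only shows why
# raising instead of asserting keeps the engine alive.
def read_blocks_for_req(req_id: str, read_blocks, mark_failed) -> None:
    try:
        read_blocks(req_id)  # may raise ValueError on a block-count mismatch
    except Exception:
        # This one transfer is logged and marked failed; other requests
        # keep being processed.
        logger.exception("KV transfer failed for request %s", req_id)
        mark_failed(req_id)
```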
```python
ratio = physical_blocks_per_logical
logical_blocks = num_blocks // ratio
```
The calculation `logical_blocks = num_blocks // ratio` assumes that the total number of kernel blocks is a perfect multiple of the physical-to-logical ratio. If this invariant is violated, the descriptor IDs for SSM regions will be incorrectly computed, leading to memory corruption or incorrect data transfer. An explicit check should be added to verify this invariant.
Suggested change:
```diff
 ratio = physical_blocks_per_logical
+if num_blocks % ratio != 0:
+    raise ValueError(f"num_blocks {num_blocks} is not a multiple of "
+                     f"physical_blocks_per_logical {ratio}")
 logical_blocks = num_blocks // ratio
```
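To see what the check guards against, a tiny illustration (numbers invented): with `num_blocks = 10` and `ratio = 4`, floor division silently drops coverage for two kernel blocks, and every SSM descriptor ID derived from `logical_blocks` is shifted.

```python
num_blocks, ratio = 10, 4             # illustrative values only
logical_blocks = num_blocks // ratio
print(logical_blocks)                 # 2, but 10 / 4 = 2.5
print(num_blocks % ratio)             # 2, the nonzero remainder the new check catches
```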
```python
if j < num_fa_descs:
    chunk = local_len // fa_num_splits
    handle.append((addr + fa_slot * chunk, chunk, dev))
else:
    chunk = local_len // ssm_num_splits
    handle.append((addr + p_idx * chunk, chunk, dev))
```
Using integer division to compute chunk sizes for split handles can lead to silent data loss if the total length is not perfectly divisible by the number of splits. While standard vLLM configurations usually satisfy this, heterogeneous TP scenarios with non-standard world sizes could trigger this issue. It is safer to explicitly check for divisibility.
Suggested change:
```diff
 if j < num_fa_descs:
+    if local_len % fa_num_splits != 0:
+        raise ValueError(f"FA descriptor length {local_len} is not "
+                         f"divisible by split count {fa_num_splits}")
     chunk = local_len // fa_num_splits
     handle.append((addr + fa_slot * chunk, chunk, dev))
 else:
+    if local_len % ssm_num_splits != 0:
+        raise ValueError(f"SSM descriptor length {local_len} is not "
+                         f"divisible by split count {ssm_num_splits}")
     chunk = local_len // ssm_num_splits
     handle.append((addr + p_idx * chunk, chunk, dev))
```
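A quick numeric sketch (values invented) of the silent loss the comment describes: when the region length is not divisible by the split count, the tail bytes are never covered by any chunk.

```python
local_len, fa_num_splits = 10, 3    # illustrative values only
chunk = local_len // fa_num_splits  # 3
covered = chunk * fa_num_splits     # 9
print(local_len - covered)          # 1 byte is never addressed by any handle
```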
Force-pushed 4852347 to fcf7418.
Force-pushed fcf7418 to a6e5266.
NickLucche left a comment:
Left some comments, thanks @ZhanqiuHu!
```python
# ------------------------------------------------------------------
# Plan executors (static — no self access)
# ------------------------------------------------------------------
```
this feels a little bit too "claudy"
```python
@dataclass(frozen=True)
class RegionPlan:
    """Geometry for one descriptor region.

    Everything needed to build NIXL descriptors and compute descriptor
    IDs is baked in. The caller plugs in ``base_addr`` and
    ``device_id`` when constructing the final descriptor tuples.
```
I am not fully convinced about this abstraction.
I am afraid this may actually be harder to work with than a basic region described by (base_addr, len).
Like, do we care about keeping track of things like
- page_stride
- offset_in_page

once the starting address of the region has been computed?
'Cause if we don't, we might as well just store the address and len directly
This is purely my personal preference, but I was thinking of moving the geometry computation (stride, offset, descriptor size) from `worker.py` into `transfer_plan.py` to reduce the worker code. `RegionPlan` packs the output of that. Originally the functions included `base_addr` and `device_id`, but then I wanted to reduce the arguments to `_build_fa_regions`, `build_fa_local_regions`, and `build_mamba_local_regions`, and `base_addr` and `device_id` are not used for block geometry computation.
I was also thinking of adding parameters like `descs_per_block` and `desc_stride_bytes` to `RegionPlan`, so we can handle different cache groups with different block size ratios (e.g., Gemma4 HeteroTP where SWA and FA have different tokens-per-block).
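To make the trade-off concrete, a sketch of the two shapes being weighed in this thread. Field names beyond the ones mentioned above (page_stride, offset_in_page, base_addr, len) are assumptions, not the PR's actual definitions.

```python
from dataclasses import dataclass

# Geometry-carrying variant, roughly what RegionPlan does in the PR:
# the caller supplies base_addr/device_id when emitting descriptors.
@dataclass(frozen=True)
class RegionPlanSketch:
    page_stride: int      # bytes between consecutive pages
    offset_in_page: int   # where this region starts within a page
    region_len: int       # bytes per descriptor

# Minimal variant NickLucche suggests: resolve geometry up front and
# keep only what descriptor construction actually consumes.
@dataclass(frozen=True)
class RegionSketch:
    base_addr: int
    length: int
```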
```python
        handle.append((addr + p_idx * chunk, chunk, dev))
    result.append(handle)

return result
```
not that it matters much in terms of speed, but this whole method could yield `handle` here and be a generator
Sounds good, will update `_build_local_splits_from_plan` to yield `handle`.
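For reference, a toy sketch of the yield-based shape being agreed on here. The real method builds NIXL `(addr, len, dev)` tuples from the plan; this only shows the generator pattern, and the signature is assumed.

```python
from collections.abc import Iterator

def build_local_splits(
    regions: list[tuple[int, int, int]],  # (addr, length, dev), assumed shape
    num_splits: int,
) -> Iterator[list[tuple[int, int, int]]]:
    for addr, length, dev in regions:
        chunk = length // num_splits
        # Yield each handle lazily instead of appending to a result list.
        yield [(addr + i * chunk, chunk, dev) for i in range(num_splits)]

# Callers that need a list can still materialize it explicitly.
handles = list(build_local_splits([(0x1000, 8, 0)], 2))
```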
```python
all_descs: list[np.ndarray] = []
for i, group in enumerate(block_ids):
    group_arr = np.asarray(group)
    spec_type = plan.group_spec_types[i]
    if _is_attention_spec(spec_type):
        fa_region_ids = np.arange(num_fa_regions)[:, None]
        all_descs.append(
            (fa_region_ids * num_blocks + group_arr[None, :]).flatten()
        )
```
I think we lost a nice optimization for non-mamba models that used vectorized np ops only here:

```python
if not self._has_mamba:
    block_ids = np.concatenate(block_ids)[None, :]
    descs_ids = region_ids * num_blocks + block_ids
    return descs_ids.flatten()
else:
    # NOTE (NickLucche) SSM and Attention blocks regions can be exchanged
```
@ZhanqiuHu could you double check
Yes, that's true. I was thinking of removing the branch. I can add it back like this:
```python
@staticmethod
def _compute_desc_ids_from_plan(
    plan: EngineTransferPlan,
    block_ids: BlockIds,
    dst_num_blocks: int,
    block_size_ratio: float | None,
    physical_blocks_per_logical: int,
) -> np.ndarray:
    """Compute NIXL descriptor IDs for given block IDs."""
    num_fa_regions = len(plan.fa_regions)
    num_ssm_regions = len(plan.ssm_regions)
    num_blocks = dst_num_blocks
    if block_size_ratio is not None:
        num_blocks = int(num_blocks * block_size_ratio)
    ratio = physical_blocks_per_logical
    logical_blocks = num_blocks // ratio
    num_fa_descs = num_fa_regions * num_blocks
    # All-attention fast path: single vectorized broadcast.
    if num_ssm_regions == 0:
        block_arr = np.concatenate(block_ids)[None, :]
        region_ids = np.arange(num_fa_regions)[:, None]
        return (region_ids * num_blocks + block_arr).flatten()
    # NOTE (NickLucche) With HMA, every kv group has the same number
    # of layers and layers from different groups share the same kv
    # tensor. Therefore we compute desc IDs per group using the
    # right stride:
    # FA descs have num_blocks entries per region (kernel granularity),
    # SSM descs have logical_blocks entries per region (no kernel
    # splitting).
    all_descs: list[np.ndarray] = []
    for i, group in enumerate(block_ids):
        group_arr = np.asarray(group)
        if _is_attention_spec(plan.group_spec_types[i]):
            fa_region_ids = np.arange(num_fa_regions)[:, None]
            all_descs.append(
                (fa_region_ids * num_blocks + group_arr[None, :]).flatten()
            )
        elif _is_ssm_spec(plan.group_spec_types[i]):
            # NOTE (NickLucche) SSM and Attention block regions can
            # be exchanged arbitrarily by manager. Therefore, descs
            # are laid out as:
            # [descs_fa (all regions) | descs_ssm (all regions)].
            # num_fa_descs offset must be computed per-engine since
            # P and D can have different num_blocks (and thus
            # different FA desc counts).
            ssm_region_ids = np.arange(num_ssm_regions)[:, None]
            all_descs.append(
                (
                    ssm_region_ids * logical_blocks
                    + group_arr[None, :]
                    + num_fa_descs
                ).flatten()
            )
        else:
            raise ValueError(
                f"Unknown spec type {plan.group_spec_types[i]} at index {i}"
            )
    return np.concatenate(all_descs)
```
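A toy walk-through of the desc-ID layout this method produces (all numbers invented; no vLLM types needed): FA IDs are region-major with stride `num_blocks`, and SSM IDs are offset past all FA descs.

```python
import numpy as np

num_fa_regions, num_ssm_regions = 2, 1
num_blocks = 8                      # kernel blocks per FA region
logical_blocks = 4                  # logical blocks per SSM region
num_fa_descs = num_fa_regions * num_blocks

fa_blocks = np.array([0, 3])        # block IDs for an attention group
ssm_blocks = np.array([1, 2])       # block IDs for an SSM group

fa_ids = (np.arange(num_fa_regions)[:, None] * num_blocks + fa_blocks).flatten()
ssm_ids = (np.arange(num_ssm_regions)[:, None] * logical_blocks
           + ssm_blocks + num_fa_descs).flatten()
print(fa_ids)   # [ 0  3  8 11], FA descs, region-major
print(ssm_ids)  # [17 18], SSM descs start after all 16 FA descs
```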
```python
# Per-group ordered source ranks. Position = local piece index.
source_ranks_per_group: tuple[tuple[int, ...], ...]

# Superset of all source ranks (union of all groups).
all_source_ranks: tuple[int, ...]
```
I don't think it's very clear what "source_rank" is here.
`source_rank` is supposed to mean the remote TP ranks that this local rank reads from. `source_ranks_per_group` is the `source_rank` for each kv cache group (e.g., FA source ranks will be < mamba source ranks if FA is replicated and Mamba is sharded). Should we rename it to something else?
let's just add this as a comment above here
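Something along these lines, perhaps (wording is only a suggestion, not the PR's final text):

```python
# Suggested wording, not final: a "source rank" is the remote (producer)
# TP rank this local consumer rank pulls KV blocks from. It is tracked
# per group because groups can shard differently, e.g. replicated
# attention reads from fewer remote ranks than TP-sharded Mamba state.
source_ranks_per_group: tuple[tuple[int, ...], ...]
all_source_ranks: tuple[int, ...]
```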
```python
    group_spec_types: tuple[type[KVCacheSpec], ...],
    local_physical_blocks_per_logical: int,
) -> EngineTransferPlan:
    """Generate transfer plan for dense (attention-only) models."""
```
nit: very minor, but we may want to choose some other name, or clarify that "dense" here still encompasses all non-mamba attention variants, including SW and DSA
How about `generate_pure_attention_plan()` and `generate_ssm_attention_hybrid_plan()`?
let's ask claude for some more options here
```python
def _build_local_descs(
    self,
    base_addresses: list[int],
    block_size_ratio: int,
) -> list[tuple[int, int, int]]:
```
this method is only used once in `register_local_xfer_handler`. I don't think there's a lot of value in the added clarity of separating this snippet out as a staticmethod
Agree, will inline it.
```python
def _compute_read_specs_from_plan(
    plan: EngineTransferPlan,
    local_block_ids: BlockIds,
    remote_block_ids: BlockIds,
) -> list[ReadSpec]:
    """Compute read specs from plan.
```
I am also not sure whether this should be a function, or an inline for-loop
Agree, will inline it.
```python
# ..but we still need to notify the other remote ranks that we
# have the blocks we need so they can update the request state.
```
nit: this continuation comment has lost its first part now. It was meant to be:

```python
# MLA opt: when P TP > D TP, only a single read is executed for
# the first remote rank (cache is duplicated)..
# ..but we still need to notify the other remote ranks that we
# have the blocks we need so they can update the request state.
```

but we can just rephrase
The first part is actually before:

```python
if self.use_mla and tp_ratio < 0:
    read_specs = read_specs[:1]
```

The change is already in main (main ref), probably introduced by a previous PR.
```diff
 self._read_blocks(
     request_id=req_id,
     dst_engine_id=meta.remote.engine_id,
     remote_request_id=meta.remote.request_id,
-    local_block_ids=local_ids,
-    remote_block_ids=remote_ids,
+    local_block_ids=local_block_ids,
+    remote_block_ids=remote_block_ids,
```
not fully sure, but should we change the `_read_blocks` interface to just be:

```python
def _read_blocks(
    self,
    dst_engine_id: str,
    request_id: str,
    read_spec: ReadSpec,
    remote_request_id: str,
    local_xfer_side_handle: int,
    remote_xfer_side_handle: int,
):
```
Yes, I think we can do:

```python
def _read_blocks(
    self,
    read_spec: ReadSpec,
    dst_engine_id: str,
    remote_request_id: str,
    local_xfer_side_handle: int,
    remote_xfer_side_handle: int,
):
    local_block_ids = read_spec.local_block_ids
    remote_block_ids = read_spec.remote_block_ids
    remote_rank = read_spec.remote_rank
```
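For readers following along, the `ReadSpec` fields implied by this exchange would look roughly like this. Only the three accessed fields are grounded in the thread; the dataclass shape itself is an assumption.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ReadSpec:
    local_block_ids: tuple[int, ...]   # consumer-side block IDs to write
    remote_block_ids: tuple[int, ...]  # producer-side block IDs to read
    remote_rank: int                   # which remote TP rank to read from
```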
ZhanqiuHu left a comment:
Addressed comments
```python
    Get the block length for one K/V element (K and V have the same size).

    For FA and other backends, this is equal to the length of the whole
    block, as K and V are in separate regions.
    For FlashInfer, this is half the length of the whole block, as K and V
    share the same region.
    Similarly, for SSM-based models, state and conv are interleaved, but
    crucially their size differs.

    Reference diagram:

                      KVCacheTensor (Shared)
                        /              \\
                       /                \\
                      /                  \\
    Attention (FlashInfer) View       Mamba View
               |                          |
               |                          |
    +-------------------+     +--------------------+
    |   KVCacheTensor   |     |   KVCacheTensor    |
    |                   |     |                    |
    |<----- page ------>|     |<----- page ------->|
    |       size        |     |       size         |
    |  Key 0  |  Val 0  |     | Conv 0  |  SSM 0   |
    |  Key 1  |  Val 1  |     | Conv 1  |  SSM 1   |
    |   ...   |   ...   |     |  ...    |   ...    |
    | Key N-2 | Val N-2 |     | Conv N-2| SSM N-2  |
    | Key N-1 | Val N-1 |     | Conv N-1| SSM N-1  |
    +-------------------+     +--------------------+
    |1st_split-2nd_split|     |1st_split-2nd_split |
    """
```
I think we lost this whole diagram :( @ZhanqiuHu
```python
@dataclass(frozen=True)
class MambaEngineTransferInfo(EngineTransferInfo):
    """Extends ``EngineTransferInfo`` with Mamba-hybrid transfer geometry.
```
nice cleanup in this file
Force-pushed a6e5266 to 51f297a.
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed 51f297a to 7881c46.
NickLucche left a comment:
I think this is a particularly complicated part of the codebase, thanks for the great work here trying to improve its clarity and maintainability, @ZhanqiuHu!
Happy to finally approve this PR. Left some comments, mostly around preserving some of the context we had in comments. Will push something to help the work around the clock.
```python
self._physical_blocks_per_logical_kv_block = (
    self.block_size // kernel_block_size
)
```
I think this might be cruft
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TP mapping computation for NIXL KV cache transfers."""
```
I think this file is `tp_mapping` now :)
```python
# With homogeneous TP, D pulls the whole kv cache from corresponding
# rank. With heterogeneous TP, prepare the descriptors by splitting the
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].

# Register all remote blocks, but only the corresponding kv heads.
```
this might still be good context
```python
)

if transfer_topo.is_kv_layout_blocks_first:
    # With FlashInfer index V separately to allow head splitting.
```

```python
# Separate and interleave K/V regions to maintain the same
# descs ordering. This is needed for selecting contiguous heads
# when split across TP ranks.
```

```python
# Mamba conv state is always TP-sharded, even when attention KV
# is replicated (num_kv_heads < tp_size).
```
```python
# ---- Mamba-HMA per-engine state (only used when self._has_mamba) ----
# NOTE (ZhanqiuHu): _physical_blocks_per_logical MUST be per-engine.
# physical_blocks_per_logical = ceil((conv_bytes + ssm_bytes) / block_len)
# where conv/ssm bytes are per-TP-rank (dimension-sharded). With
# heterogeneous TP the per-rank sizes differ, so the ratio differs:
# e.g. Nemotron 30B: P(TP=4) → 131, D(TP=1) → 261.
```
I think this is also good context
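A worked version of the formula in that NOTE, with invented sizes (not the real Nemotron 30B numbers) to show why the ratio must be computed per engine:

```python
import math

block_len = 4096  # bytes per kernel block (assumed)
for tp in (1, 4):
    # conv/ssm state is dimension-sharded, so per-rank bytes shrink with TP.
    conv_bytes = 1_048_576 // tp  # illustrative per-rank conv state size
    ssm_bytes = 2_097_152 // tp   # illustrative per-rank ssm state size
    ratio = math.ceil((conv_bytes + ssm_bytes) / block_len)
    print(f"TP={tp}: physical_blocks_per_logical={ratio}")
# TP=1 -> 768, TP=4 -> 192: P and D with different TP get different ratios.
```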
Force-pushed f23d68c to 98c3207.
Commits:

- …blocks_per_logical removal
- …_policy and use_mla, set num_descs
- …nsfer_info: Add physical_blocks_per_logical to TransferTopology and pass transfer_topo directly to build_engine_transfer_info, reducing the method's parameter count from 10 to 6.
- …11→4)
- …_id (7→4)
- …to_kernel_block_ids: Move model-specific block ID expansion and trimming logic out of worker.py hot paths into a unified logical_to_kernel_block_ids() in transfer_plan.py. Per-request functions now consume only the pre-computed EngineTransferPlan (model-agnostic); model awareness is confined to init and plan generation (if/else dense vs mamba).
  - Add transfer_plan.py with plan generators, executors, and logical_to_kernel_block_ids (per-group physical_per_logical).
  - Generate EngineTransferPlan during handshake (generate_dense_plan or generate_mamba_plan), stored in _transfer_plans dict.
  - Replace _logical_to_remote_kernel_block_ids in _read_blocks_for_req with plan-based logical_to_kernel_block_ids (no model branching).
  - Make block trimming in _read_blocks unconditional (SSM groups are no-op due to shared block table).
  - Thin-wrap _logical_to_kernel_block_ids for local expansion, delegating to the same unified function.
  - Add _conv_decomp to worker init for mamba plan generation.
- The unified logical_to_kernel_block_ids had two bugs: (1) dense remote used ratio=1 instead of the local kernel ratio; (2) mamba FA remote used the same value for stride and count, but the old code used remote_ratio as stride and local_ratio as count. Restore the original two-method design: _logical_to_kernel_block_ids for local expansion (same as main) and _logical_to_remote_kernel_block_ids for remote mamba expansion (same as main). Only difference from main: remote_ratio comes from plan.remote_physical_blocks_per_logical instead of self._mamba_phys_ratio[engine_id].
- Add remote_expansion_stride to EngineTransferPlan so the per-request hot path always calls _logical_to_remote_kernel_block_ids with the plan's stride, removing the last model-specific branch from _read_blocks_for_req. Dense plan: stride = local_physical_blocks_per_logical (stride == count). Mamba plan: stride = remote_physical_blocks_per_logical (stride != count).
- Replace the bool-based is_mamba_group with a GroupKind enum (FA, SWA, MAMBA, GDN) so the transfer layer can dispatch on group type without model-specific branching. Shared behavior is captured by properties (is_attention, is_ssm) — no code duplication when adding new group types. Unsupported KVCacheSpec types raise NotImplementedError.
- …ake field
- …nsion

All commits are signed off by Zhanqiu Hu <zhu@redhat.com>; two later commits are signed off by NickLucche <nlucches@redhat.com>.
Head branch was pushed to by a user without write access
Force-pushed 98c3207 to 1232865.
Refactor 3/N