
nixl refactor: new transfer design#40731

Open
ZhanqiuHu wants to merge 44 commits into vllm-project:main from
ZhanqiuHu:nixl-refactor-plan-based-poc

Conversation

@ZhanqiuHu
Contributor

@ZhanqiuHu ZhanqiuHu commented Apr 23, 2026

Refactor 3/N

@ZhanqiuHu ZhanqiuHu changed the title nixl refactor: plan-based transfer design nixl refactor: new transfer design Apr 23, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request refactors the NIXL connector to use a plan-based transfer design, which unifies Dense and Mamba implementations into a model-agnostic execution path. By pre-generating an EngineTransferPlan during the handshake, the hot path is simplified and model-specific branching is removed. The review feedback highlights several opportunities to improve robustness, specifically by replacing assertions with explicit error handling for block count mismatches and adding divisibility checks when calculating logical blocks and chunk sizes to prevent potential data corruption or crashes in heterogeneous tensor parallel configurations.

Comment on lines 1788 to +1790
num_local_blocks = len(local_block_ids[i])
if not self._is_mamba_group[i]:
assert num_local_blocks <= num_remote_blocks
# Partial prefix cache hit: just read uncomputed blocks.
# Skip mamba groups — their blocks represent full state (conv+ssm),
# not per-token data, so trimming would corrupt the transfer.
if num_local_blocks < num_remote_blocks and not self._is_mamba_group[i]:
assert num_local_blocks <= len(remote_group)
if num_local_blocks < len(remote_group):
Contributor

high

The assertion assert num_local_blocks <= len(remote_group) can cause a hard crash of the worker if the producer sends fewer blocks than the consumer expects (e.g., due to a race condition or misconfiguration). It is safer to handle this as a transfer failure for the specific request, allowing the engine to continue processing other requests. Since _read_blocks is called within a try-except block in _read_blocks_for_req, raising a ValueError will be caught and handled gracefully.

Suggested change
num_local_blocks = len(local_block_ids[i])
if not self._is_mamba_group[i]:
assert num_local_blocks <= num_remote_blocks
# Partial prefix cache hit: just read uncomputed blocks.
# Skip mamba groups — their blocks represent full state (conv+ssm),
# not per-token data, so trimming would corrupt the transfer.
if num_local_blocks < num_remote_blocks and not self._is_mamba_group[i]:
assert num_local_blocks <= len(remote_group)
if num_local_blocks < len(remote_group):
num_local_blocks = len(local_block_ids[i])
if num_local_blocks > len(remote_group):
raise ValueError(
f"Group {i}: local block count ({num_local_blocks}) "
f"exceeds remote block count ({len(remote_group)})")
if num_local_blocks < len(remote_group):

Comment on lines +565 to +566
ratio = physical_blocks_per_logical
logical_blocks = num_blocks // ratio
Contributor

high

The calculation of logical_blocks = num_blocks // ratio assumes that the total number of kernel blocks is a perfect multiple of the physical-to-logical ratio. If this invariant is violated, the descriptor IDs for SSM regions will be incorrectly computed, leading to memory corruption or incorrect data transfer. An explicit check should be added to verify this invariant.

Suggested change
ratio = physical_blocks_per_logical
logical_blocks = num_blocks // ratio
ratio = physical_blocks_per_logical
if num_blocks % ratio != 0:
raise ValueError(f"num_blocks {num_blocks} is not a multiple of "
f"physical_blocks_per_logical {ratio}")
logical_blocks = num_blocks // ratio

Comment on lines +646 to +651
if j < num_fa_descs:
chunk = local_len // fa_num_splits
handle.append((addr + fa_slot * chunk, chunk, dev))
else:
chunk = local_len // ssm_num_splits
handle.append((addr + p_idx * chunk, chunk, dev))
Contributor

high

Using integer division to compute chunk sizes for split handles can lead to silent data loss if the total length is not perfectly divisible by the number of splits. While standard vLLM configurations usually satisfy this, heterogeneous TP scenarios with non-standard world sizes could trigger this issue. It is safer to explicitly check for divisibility.

Suggested change
if j < num_fa_descs:
chunk = local_len // fa_num_splits
handle.append((addr + fa_slot * chunk, chunk, dev))
else:
chunk = local_len // ssm_num_splits
handle.append((addr + p_idx * chunk, chunk, dev))
if j < num_fa_descs:
if local_len % fa_num_splits != 0:
raise ValueError(f"FA descriptor length {local_len} is not "
f"divisible by split count {fa_num_splits}")
chunk = local_len // fa_num_splits
handle.append((addr + fa_slot * chunk, chunk, dev))
else:
if local_len % ssm_num_splits != 0:
raise ValueError(f"SSM descriptor length {local_len} is not "
f"divisible by split count {ssm_num_splits}")
chunk = local_len // ssm_num_splits
handle.append((addr + p_idx * chunk, chunk, dev))

@ZhanqiuHu ZhanqiuHu force-pushed the nixl-refactor-plan-based-poc branch 2 times, most recently from 4852347 to fcf7418 Compare April 27, 2026 15:41
@ZhanqiuHu ZhanqiuHu marked this pull request as ready for review April 27, 2026 18:01
@claude claude Bot left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

Collaborator

@NickLucche NickLucche left a comment

Left some comments, thanks @ZhanqiuHu !

Comment on lines +91 to +93
# ------------------------------------------------------------------
# Plan executors (static — no self access)
# ------------------------------------------------------------------
Collaborator

this feels a little bit too "claudy"

Comment on lines +54 to +60
@dataclass(frozen=True)
class RegionPlan:
"""Geometry for one descriptor region.

Everything needed to build NIXL descriptors and compute descriptor
IDs is baked in. The caller plugs in ``base_addr`` and
``device_id`` when constructing the final descriptor tuples.
Collaborator

I am not fully convinced about this abstraction.
I am afraid this may actually be harder to work with than a basic region described by (base_addr, len).

Like do we care about keeping track of things like

  • page_stride
  • offset_in_page

once the starting address of the region has been computed?
'Cause if we don't, we might as well just store the address and len directly

Contributor Author

This is purely my personal preference, but I was thinking of moving the
geometry computation (stride, offset, descriptor size) from worker.py
into transfer_plan.py to reduce the worker code. RegionPlan packs
the output of that. Originally the functions included base_addr and
device_id, but then I wanted to reduce the arguments to
_build_fa_regions, build_fa_local_regions, and
build_mamba_local_regions, and base_addr and device_id are not used for block geometry computation.

I was also thinking of adding parameters like descs_per_block and desc_stride_bytes to RegionPlan,
so we can handle different cache groups with different block size ratios
(e.g., Gemma4 HeteroTP where SWA and FA have different tokens-per-block).
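To make the trade-off in this thread concrete, here is a minimal sketch of what a geometry-only `RegionPlan` could look like, with `base_addr`/`device_id` plugged in late as described above. The field set and the `descriptors` helper are illustrative assumptions, not the PR's actual class:

```python
from dataclasses import dataclass
from typing import Iterator

@dataclass(frozen=True)
class RegionPlan:
    # Illustrative field set; the real class in the PR may differ.
    page_stride: int     # bytes between consecutive pages
    offset_in_page: int  # byte offset of this region within each page
    desc_len: int        # length of one descriptor in bytes
    num_pages: int       # number of pages covered by the region

    def descriptors(self, base_addr: int, device_id: int) -> Iterator[tuple[int, int, int]]:
        # Geometry is baked in at plan-generation time; only the
        # engine-specific base address and device are supplied here.
        for page in range(self.num_pages):
            addr = base_addr + page * self.page_stride + self.offset_in_page
            yield (addr, self.desc_len, device_id)

plan = RegionPlan(page_stride=4096, offset_in_page=128, desc_len=256, num_pages=2)
descs = list(plan.descriptors(base_addr=0x1000, device_id=0))
# one (addr, len, device) tuple per page
```

Keeping `page_stride`/`offset_in_page` (rather than a flat `(base_addr, len)`) is only useful if descriptors are rebuilt per engine, as the reply above argues.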

Comment on lines +252 to +255
handle.append((addr + p_idx * chunk, chunk, dev))
result.append(handle)

return result
Collaborator

not that it matters much in terms of speed, but this whole method could yield handle here and be a generator

Contributor Author

Sounds good, will update _build_local_splits_from_plan to yield handle.
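The suggested change amounts to replacing list accumulation with `yield`; a minimal sketch, with a made-up signature (the real `_build_local_splits_from_plan` takes different arguments):

```python
def build_local_splits(regions, num_splits):
    """Yield one split handle per region instead of accumulating a
    result list (illustrative sketch of the generator refactor)."""
    for addr, length, dev in regions:
        chunk = length // num_splits
        # One (addr, len, device) tuple per split of the region.
        yield [(addr + p * chunk, chunk, dev) for p in range(num_splits)]

# Callers that still need a list can simply materialize the generator:
handles = list(build_local_splits([(0, 8, 0), (100, 8, 1)], num_splits=2))
```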

Comment on lines +143 to +151
all_descs: list[np.ndarray] = []
for i, group in enumerate(block_ids):
group_arr = np.asarray(group)
spec_type = plan.group_spec_types[i]
if _is_attention_spec(spec_type):
fa_region_ids = np.arange(num_fa_regions)[:, None]
all_descs.append(
(fa_region_ids * num_blocks + group_arr[None, :]).flatten()
)
Collaborator

I think we lost a nice optimization for non mamba models that used vectorized np ops only here

        if not self._has_mamba:
            block_ids = np.concatenate(block_ids)[None, :]
            descs_ids = region_ids * num_blocks + block_ids
            return descs_ids.flatten()
        else:
            # NOTE (NickLucche) SSM and Attention blocks regions can be exchanged

@ZhanqiuHu could you double check

Contributor Author

Yes, that's true. I was thinking of removing the branch. I can add it back like this:

@staticmethod
def _compute_desc_ids_from_plan(
    plan: EngineTransferPlan,
    block_ids: BlockIds,
    dst_num_blocks: int,
    block_size_ratio: float | None,
    physical_blocks_per_logical: int,
) -> np.ndarray:
    """Compute NIXL descriptor IDs for given block IDs."""
    num_fa_regions = len(plan.fa_regions)
    num_ssm_regions = len(plan.ssm_regions)

    num_blocks = dst_num_blocks
    if block_size_ratio is not None:
        num_blocks = int(num_blocks * block_size_ratio)
    ratio = physical_blocks_per_logical
    logical_blocks = num_blocks // ratio

    num_fa_descs = num_fa_regions * num_blocks

    # All-attention fast path: single vectorized broadcast.
    if num_ssm_regions == 0:
        block_arr = np.concatenate(block_ids)[None, :]
        region_ids = np.arange(num_fa_regions)[:, None]
        return (region_ids * num_blocks + block_arr).flatten()

    # NOTE (NickLucche) With HMA, every kv group has the same number
    # of layers and layers from different groups share the same kv
    # tensor.  Therefore we compute desc IDs per group using the
    # right stride:
    # FA descs have num_blocks entries per region (kernel granularity),
    # SSM descs have logical_blocks entries per region (no kernel
    # splitting).
    all_descs: list[np.ndarray] = []
    for i, group in enumerate(block_ids):
        group_arr = np.asarray(group)
        if _is_attention_spec(plan.group_spec_types[i]):
            fa_region_ids = np.arange(num_fa_regions)[:, None]
            all_descs.append(
                (fa_region_ids * num_blocks + group_arr[None, :]).flatten()
            )
        elif _is_ssm_spec(plan.group_spec_types[i]):
            # NOTE (NickLucche) SSM and Attention block regions can
            # be exchanged arbitrarily by manager.  Therefore, descs
            # are laid out as:
            #   [descs_fa (all regions) | descs_ssm (all regions)].
            # num_fa_descs offset must be computed per-engine since
            # P and D can have different num_blocks (and thus
            # different FA desc counts).
            ssm_region_ids = np.arange(num_ssm_regions)[:, None]
            all_descs.append(
                (
                    ssm_region_ids * logical_blocks
                    + group_arr[None, :]
                    + num_fa_descs
                ).flatten()
            )
        else:
            raise ValueError(
                f"Unknown spec type {plan.group_spec_types[i]} at index {i}"
            )

    return np.concatenate(all_descs)

Collaborator

that's ok

Comment on lines +88 to +92
# Per-group ordered source ranks. Position = local piece index.
source_ranks_per_group: tuple[tuple[int, ...], ...]

# Superset of all source ranks (union of all groups).
all_source_ranks: tuple[int, ...]
Collaborator

I dont think it's very clear what "source_rank" is here..

Contributor Author

source_rank is supposed to mean the remote TP ranks that this local rank reads from.

source_ranks_per_group is the source_rank for each kv cache group (e.g., FA source ranks will be < mamba source ranks if FA is replicated and Mamba is sharded).

Should we rename it to something else?
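For illustration, here are hypothetical values these fields could hold for a D worker pulling from a P deployment with TP=2 where attention KV is replicated but mamba state is sharded (the concrete ranks are made up, not taken from the PR):

```python
# Hypothetical layout: D rank reading from a P deployment with TP=2.
source_ranks_per_group = (
    (0,),    # attention group: replicated remote KV, one source rank suffices
    (0, 1),  # mamba group: sharded state, piece i comes from remote rank i
)
# all_source_ranks is the union across groups: every remote rank this
# local rank must handshake/transfer with.
all_source_ranks = tuple(sorted({r for g in source_ranks_per_group for r in g}))
```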

Collaborator

let's just add this as comment above here

group_spec_types: tuple[type[KVCacheSpec], ...],
local_physical_blocks_per_logical: int,
) -> EngineTransferPlan:
"""Generate transfer plan for dense (attention-only) models."""
Collaborator

nit: very minor, but we may want to choose some other name or clarify that for dense we still encompass all non-mamba, including SW and DSA

Contributor Author

How about generate_pure_attention_plan() and
generate_ssm_attention_hybrid_plan()?

Collaborator

let's ask claude for some more options here

Comment on lines +257 to +261
def _build_local_descs(
self,
base_addresses: list[int],
block_size_ratio: int,
) -> list[tuple[int, int, int]]:
Collaborator

this method is only used once in register_local_xfer_handler.
I don't think there's a lot of value in added clarity in separating this snippet as a staticmethod

Contributor Author

Agree, will inline it.

Comment on lines +174 to +179
def _compute_read_specs_from_plan(
plan: EngineTransferPlan,
local_block_ids: BlockIds,
remote_block_ids: BlockIds,
) -> list[ReadSpec]:
"""Compute read specs from plan.
Collaborator

I am also not sure whether this should be a function, or an inline for

Contributor Author

Agree, will inline it.

Comment on lines +1853 to +1854
# ..but we still need to notify the other remote ranks that we
# have the blocks we need so they can update the request state.
Collaborator

nit: this continuation comment has lost its first part now. It was meant to be

            # MLA opt: when P TP > D TP, only a single read is executed for
            # the first remote rank (cache is duplicated)..
            # ..but we still need to notify the other remote ranks that we
            # have the blocks we need so they can update the request state.

but we can just rephrase

Contributor Author

The first part is actually before

if self.use_mla and tp_ratio < 0:
            read_specs = read_specs[:1]

The change is already in main (main ref), probably introduced by a previous PR.

Comment on lines +1841 to +1846
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote.engine_id,
remote_request_id=meta.remote.request_id,
local_block_ids=local_ids,
remote_block_ids=remote_ids,
local_block_ids=local_block_ids,
remote_block_ids=remote_block_ids,
Collaborator

not fully sure, but should we change _read_blocks interface to just be

    def _read_blocks(
        self,
        dst_engine_id: str,
        request_id: str,
        read_spec: ReadSpec, 
        remote_request_id: str,
        local_xfer_side_handle: int,
        remote_xfer_side_handle: int,
    ):

Contributor Author

Yes, I think we can do

def _read_blocks(
    self,
    read_spec: ReadSpec,
    dst_engine_id: str,
    remote_request_id: str,  
    local_xfer_side_handle: int,  
    remote_xfer_side_handle: int,  
):
    local_block_ids = read_spec.local_block_ids
    remote_block_ids = read_spec.remote_block_ids
    remote_rank = read_spec.remote_rank

Contributor Author

@ZhanqiuHu ZhanqiuHu left a comment

Addressed comments


Comment on lines -2267 to -2295
Get the block length for one K/V element (K and V have the same size).

For FA and other backends, this is equal to the length of the whole
block, as K and V are in separate regions.
For FlashInfer, this is half the length of the whole block, as K and V
share the same region.
Similarly, for SSM-based models, state and conv are interleaved, but
crucially their sizes differ.
Reference diagram:

                 KVCacheTensor (Shared)
                   /             \\
                  /               \\
  Attention (FlashInfer) View   Mamba View
            |                       |
 +-------------------+    +--------------------+
 |   KVCacheTensor   |    |    KVCacheTensor   |
 |                   |    |                    |
 |<----- page ------>|    |<------ page ------>|
 |       size        |    |        size        |
 |  Key 0  |  Val 0  |    |  Conv 0  |  SSM 0  |
 |  Key 1  |  Val 1  |    |  Conv 1  |  SSM 1  |
 |   ...   |   ...   |    |   ...    |   ...   |
 | Key N-2 | Val N-2 |    | Conv N-2 | SSM N-2 |
 | Key N-1 | Val N-1 |    | Conv N-1 | SSM N-1 |
 +-------------------+    +--------------------+
 |1st_split-2nd_split|    |1st_split-2nd_split |
"""
Collaborator

I think we lost this whole diagram :( @ZhanqiuHu
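The rule the deleted docstring described can be sketched in a few lines; the function name and flag below are illustrative, not vLLM's actual helper:

```python
def kv_element_len(block_len: int, kv_share_region: bool) -> int:
    # FA-style layouts keep K and V in separate regions, so one K/V
    # element spans the whole block; FlashInfer packs K and V side by
    # side in the same region, so each element is half the block.
    return block_len // 2 if kv_share_region else block_len
```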

Comment on lines -389 to -391
@dataclass(frozen=True)
class MambaEngineTransferInfo(EngineTransferInfo):
"""Extends ``EngineTransferInfo`` with Mamba-hybrid transfer geometry.
Collaborator

nice cleanup in this file

@ZhanqiuHu ZhanqiuHu force-pushed the nixl-refactor-plan-based-poc branch from a6e5266 to 51f297a Compare May 4, 2026 18:11
@mergify
Contributor

mergify Bot commented May 4, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ZhanqiuHu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 4, 2026
@ZhanqiuHu ZhanqiuHu force-pushed the nixl-refactor-plan-based-poc branch from 51f297a to 7881c46 Compare May 4, 2026 18:17
@mergify mergify Bot removed the needs-rebase label May 4, 2026
Collaborator

@NickLucche NickLucche left a comment

I think this is a particularly complicated part of the codebase; thanks for the great work improving its clarity and maintainability, @ZhanqiuHu!

Happy to finally approve this PR. Left some comments, mostly around preserving some of the context we had in comments. Will push something to help the work around the clock

Comment on lines +867 to +869
self._physical_blocks_per_logical_kv_block = (
self.block_size // kernel_block_size
)
Collaborator

I think this might be cruft

Comment on lines +1 to +3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TP mapping computation for NIXL KV cache transfers."""
Collaborator

I think this file is tp_mapping now :)

Comment on lines -1269 to -1274
# With homogeneous TP, D pulls the whole kv cache from corresponding
# rank. With heterogeneous TP, prepare the descriptors by splitting the
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].

# Register all remote blocks, but only the corresponding kv heads.
Collaborator

this might still be good context

)

if transfer_topo.is_kv_layout_blocks_first:
# With FlashInfer index V separately to allow head splitting.
Collaborator

here too

Comment on lines -1061 to -1063
# Separate and interleave K/V regions to maintain the same
# descs ordering. This is needed for selecting contiguous heads
# when split across TP ranks.
Collaborator

ditto

Comment on lines -964 to -965
# Mamba conv state is always TP-sharded, even when attention KV
# is replicated (num_kv_heads < tp_size).
Collaborator

:(

Comment on lines -271 to -276
# ---- Mamba-HMA per-engine state (only used when self._has_mamba) ----
# NOTE (ZhanqiuHu): _physical_blocks_per_logical MUST be per-engine.
# physical_blocks_per_logical = ceil((conv_bytes + ssm_bytes) / block_len)
# where conv/ssm bytes are per-TP-rank (dimension-sharded). With
# heterogeneous TP the per-rank sizes differ, so the ratio differs:
# e.g. Nemotron 30B: P(TP=4) → 131, D(TP=1) → 261.
Collaborator

I think this is also good context
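The deleted note's formula is easy to preserve as a sketch; the byte sizes below are made up to show why the ratio is per-engine (per-rank conv/ssm sizes shrink with TP), not the real Nemotron 30B numbers:

```python
import math

def physical_blocks_per_logical(conv_bytes: int, ssm_bytes: int, block_len: int) -> int:
    # Formula from the deleted note:
    #   ceil((conv_bytes + ssm_bytes) / block_len)
    # where conv/ssm bytes are per-TP-rank (dimension-sharded).
    return math.ceil((conv_bytes + ssm_bytes) / block_len)

# Hypothetical sizes: sharding 4 ways shrinks per-rank state, so the
# P engine (TP=4) ends up with a smaller ratio than the D engine (TP=1).
ratio_p = physical_blocks_per_logical(conv_bytes=1_000, ssm_bytes=4_000, block_len=40)
ratio_d = physical_blocks_per_logical(conv_bytes=4_000, ssm_bytes=16_000, block_len=40)
```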

@NickLucche NickLucche added the ready ONLY add when PR is ready to merge/full CI is needed label May 5, 2026
@ZhanqiuHu ZhanqiuHu force-pushed the nixl-refactor-plan-based-poc branch from f23d68c to 98c3207 Compare May 5, 2026 15:43
@NickLucche NickLucche enabled auto-merge (squash) May 5, 2026 17:13
ZhanqiuHu added 3 commits May 5, 2026 14:30
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
ZhanqiuHu and others added 28 commits May 5, 2026 14:30
…blocks_per_logical removal

…_policy and use_mla, set num_descs

…nsfer_info

Add physical_blocks_per_logical to TransferTopology and pass
transfer_topo directly to build_engine_transfer_info, reducing the
method's parameter count from 10 to 6.

…11→4)

…_id (7→4)

…to_kernel_block_ids

Move model-specific block ID expansion and trimming logic out of
worker.py hot paths into a unified logical_to_kernel_block_ids()
in transfer_plan.py. Per-request functions now consume only the
pre-computed EngineTransferPlan (model-agnostic); model awareness
is confined to init and plan generation (if/else dense vs mamba).

- Add transfer_plan.py with plan generators, executors, and
  logical_to_kernel_block_ids (per-group physical_per_logical).
- Generate EngineTransferPlan during handshake (generate_dense_plan
  or generate_mamba_plan), stored in _transfer_plans dict.
- Replace _logical_to_remote_kernel_block_ids in _read_blocks_for_req
  with plan-based logical_to_kernel_block_ids (no model branching).
- Make block trimming in _read_blocks unconditional (SSM groups are
  no-op due to shared block table).
- Thin-wrap _logical_to_kernel_block_ids for local expansion,
  delegating to the same unified function.
- Add _conv_decomp to worker init for mamba plan generation.

The unified logical_to_kernel_block_ids had two bugs:
1. Dense remote: used ratio=1 instead of local kernel ratio
2. Mamba FA remote: used same value for stride and count,
   but old code used remote_ratio as stride, local_ratio as count

Restore the original two-method design:
- _logical_to_kernel_block_ids: local expansion (same as main)
- _logical_to_remote_kernel_block_ids: remote mamba expansion (same as main)

Only difference from main: remote_ratio comes from
plan.remote_physical_blocks_per_logical instead of
self._mamba_phys_ratio[engine_id].

Add remote_expansion_stride to EngineTransferPlan so the per-request
hot path always calls _logical_to_remote_kernel_block_ids with the
plan's stride, removing the last model-specific branch from
_read_blocks_for_req.

Dense plan: stride = local_physical_blocks_per_logical (stride == count).
Mamba plan: stride = remote_physical_blocks_per_logical (stride != count).

Replace the bool-based is_mamba_group with a GroupKind enum
(FA, SWA, MAMBA, GDN) so the transfer layer can dispatch on
group type without model-specific branching. Shared behavior
is captured by properties (is_attention, is_ssm) — no code
duplication when adding new group types.

Unsupported KVCacheSpec types raise NotImplementedError.
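The enum described above can be sketched directly from this commit message; the member and property names follow the text, while the exact definition in the PR may differ:

```python
from enum import Enum, auto

class GroupKind(Enum):
    """Sketch of the GroupKind enum from the commit message above."""
    FA = auto()     # full attention
    SWA = auto()    # sliding-window attention
    MAMBA = auto()  # mamba SSM state
    GDN = auto()    # gated delta net

    @property
    def is_attention(self) -> bool:
        # Shared behavior is captured by properties, so adding a new
        # attention-like kind does not duplicate dispatch code.
        return self in (GroupKind.FA, GroupKind.SWA)

    @property
    def is_ssm(self) -> bool:
        return self in (GroupKind.MAMBA, GroupKind.GDN)
```

The transfer layer can then dispatch on `kind.is_attention` / `kind.is_ssm` instead of a model-specific bool.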

…ake field

…nsion

auto-merge was automatically disabled May 5, 2026 18:31

Head branch was pushed to by a user without write access

@ZhanqiuHu ZhanqiuHu force-pushed the nixl-refactor-plan-based-poc branch from 98c3207 to 1232865 Compare May 5, 2026 18:31

Labels

kv-connector ready ONLY add when PR is ready to merge/full CI is needed v1
