
[TENT] Fix NVLink IPC address for sub-allocated GPU tensors#1831

Merged
alogfans merged 2 commits into kvcache-ai:main from he-yufeng:fix/tent-nvlink-ipc-base-addr
Apr 8, 2026

Conversation

@he-yufeng
Contributor

Summary

Fixes #1829

Port of the IPC base-address fix from #1622 to the TENT NVLink transport code path.

Root cause: When desc.addr is a sub-allocation within a larger cudaMalloc segment (typical with PyTorch's caching allocator), cudaIpcGetMemHandle() returns a handle for the entire segment. But the metadata records the sub-allocation address, so the receiving side computes an incorrect offset during relocateSharedMemoryAddress(), leading to silent IPC failures in intra-node P/D disaggregation.

Fix: Call cuMemGetAddressRange() before cudaIpcGetMemHandle() to resolve the true cudaMalloc base address and allocation size. Register the IPC handle at segment granularity, and track already-registered bases via registered_base_addrs_ to avoid duplicate registrations. Apply the same base-address resolution in removeMemoryBuffer().

This mirrors the approach already merged in the non-TENT path (IntraNodeNvlinkTransport::registerLocalMemory()), adapted to TENT's BufferDesc-based API.

Changes

  • nvlink_transport.cpp: addMemoryBuffer() now resolves base address via cuMemGetAddressRange, uses base for IPC handle, deduplicates via tracking set
  • nvlink_transport.cpp: removeMemoryBuffer() resolves base address before clearing tracking state
  • nvlink_transport.h: Added std::mutex + std::unordered_set<uint64_t> for registered_base_addrs_ tracking


@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces a caching mechanism for CUDA memory registration in the NVLink transport to handle sub-allocations from caching allocators like PyTorch. It uses cuMemGetAddressRange to identify base segments and tracks them to avoid redundant IPC handle creation. However, the current implementation contains a critical bug where the IPC handle (shm_path) is not populated for cached segments, which will cause failures on the receiving side. Additionally, the reviewer suggests replacing the unordered_set with an unordered_map to store and reuse serialized IPC handles, which also addresses a potential race condition in the registration logic.

Comment on lines +235 to +256

        {
            std::lock_guard<std::mutex> lock(register_mutex_);
            if (registered_base_addrs_.count((uint64_t)base_ptr)) {
                // Already registered this cudaMalloc block, just tag transport
                desc.addr = (uint64_t)base_ptr;
                desc.length = alloc_size;
                desc.transports.push_back(TransportType::NVLINK);
                return Status::OK();
            }
        }

        cudaIpcMemHandle_t handle;
    -   CHECK_CUDA(cudaIpcGetMemHandle(&handle, (void*)desc.addr));
    +   CHECK_CUDA(cudaIpcGetMemHandle(&handle, (void*)base_ptr));
        desc.addr = (uint64_t)base_ptr;
        desc.length = alloc_size;
        desc.shm_path =
            serializeBinaryData(&handle, sizeof(cudaIpcMemHandle_t));

        {
            std::lock_guard<std::mutex> lock(register_mutex_);
            registered_base_addrs_.insert((uint64_t)base_ptr);
        }


Severity: high

The current implementation has a critical bug: when a base address is already registered (cache hit), it fails to populate desc.shm_path. This will cause the receiving side to fail during address relocation because it expects a valid IPC handle. Additionally, the check-then-act pattern outside the lock creates a race condition where multiple threads might redundantly call cudaIpcGetMemHandle for the same segment. This refactoring fixes both issues by using a map to store and retrieve the serialized handle while holding the lock.

        {
            std::lock_guard<std::mutex> lock(register_mutex_);
            auto it = registered_base_addrs_.find((uint64_t)base_ptr);
            if (it != registered_base_addrs_.end()) {
                desc.shm_path = it->second;
            } else {
                cudaIpcMemHandle_t handle;
                CHECK_CUDA(cudaIpcGetMemHandle(&handle, (void*)base_ptr));
                desc.shm_path = serializeBinaryData(&handle, sizeof(cudaIpcMemHandle_t));
                registered_base_addrs_[(uint64_t)base_ptr] = desc.shm_path;
            }
        }
        desc.addr = (uint64_t)base_ptr;
        desc.length = alloc_size;

#include <mutex>
#include <queue>
#include <string>
#include <unordered_set>


Severity: medium

Use std::unordered_map instead of std::unordered_set to store the serialized IPC handles for reuse across multiple sub-allocations within the same memory segment.

Suggested change
    -#include <unordered_set>
    +#include <unordered_map>

bool host_register_;

std::mutex register_mutex_;
std::unordered_set<uint64_t> registered_base_addrs_;


Severity: medium

Change the tracking set to a map to store the serialized IPC handle associated with each base address. This is necessary to correctly populate the shm_path for subsequent sub-allocations in the same segment.

Suggested change
    -std::unordered_set<uint64_t> registered_base_addrs_;
    +std::unordered_map<uint64_t, std::string> registered_base_addrs_;

@codecov-commenter

⚠️ Please install the Codecov app to ensure uploads and comments are reliably processed by Codecov.

Codecov Report

✅ All modified and coverable lines are covered by tests.


Collaborator

@alogfans left a comment


LGTM

@alogfans merged commit a7518f3 into kvcache-ai:main on Apr 8, 2026
16 of 17 checks passed


Development

Successfully merging this pull request may close these issues.

[Bug][TENT]: SGLang P/D intra-node communication fails due to wrong IPC address
