Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <string>
#include <unordered_set>
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Use std::unordered_map instead of std::unordered_set to store the serialized IPC handles for reuse across multiple sub-allocations within the same memory segment.

Suggested change
#include <unordered_set>
#include <unordered_map>


#include <cuda.h>
#include <cuda_runtime.h>
Expand Down Expand Up @@ -109,6 +111,9 @@ class NVLinkTransport : public Transport {
std::string machine_id_;
uint64_t async_memcpy_threshold_;
bool host_register_;

std::mutex register_mutex_;
std::unordered_set<uint64_t> registered_base_addrs_;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Change the tracking set to a map to store the serialized IPC handle associated with each base address. This is necessary to correctly populate the shm_path for subsequent sub-allocations in the same segment.

Suggested change
std::unordered_set<uint64_t> registered_base_addrs_;
std::unordered_map<uint64_t, std::string> registered_base_addrs_;

};
} // namespace tent
} // namespace mooncake
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,45 @@ Status NVLinkTransport::addMemoryBuffer(BufferDesc& desc,
// If the memory region is allocated using cuMemAlloc,
// we cannot use cudaIpcGetMemHandle, so skip it
if (options.type == MNNVL) return Status::OK();

// Resolve the true cudaMalloc base address. Caching allocators
// (e.g. PyTorch) sub-allocate tensors within larger cudaMalloc
// segments. cudaIpcGetMemHandle returns a handle for the whole
// segment, so we need to register at segment granularity.
CUdeviceptr base_ptr = 0;
size_t alloc_size = 0;
CUresult cu_err =
cuMemGetAddressRange(&base_ptr, &alloc_size, (CUdeviceptr)desc.addr);
if (cu_err != CUDA_SUCCESS) {
LOG(ERROR) << "NVLinkTransport: cuMemGetAddressRange failed for "
<< "addr 0x" << std::hex << desc.addr << std::dec
<< " (error " << cu_err << ")";
return Status::InternalError(
"cuMemGetAddressRange failed" LOC_MARK);
}

{
std::lock_guard<std::mutex> lock(register_mutex_);
if (registered_base_addrs_.count((uint64_t)base_ptr)) {
// Already registered this cudaMalloc block, just tag transport
desc.addr = (uint64_t)base_ptr;
desc.length = alloc_size;
desc.transports.push_back(TransportType::NVLINK);
return Status::OK();
}
}

cudaIpcMemHandle_t handle;
CHECK_CUDA(cudaIpcGetMemHandle(&handle, (void*)desc.addr));
CHECK_CUDA(cudaIpcGetMemHandle(&handle, (void*)base_ptr));
desc.addr = (uint64_t)base_ptr;
desc.length = alloc_size;
desc.shm_path =
serializeBinaryData(&handle, sizeof(cudaIpcMemHandle_t));

{
std::lock_guard<std::mutex> lock(register_mutex_);
registered_base_addrs_.insert((uint64_t)base_ptr);
}
Comment on lines +235 to +256
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation has a critical bug: when a base address is already registered (cache hit), it fails to populate desc.shm_path. This will cause the receiving side to fail during address relocation because it expects a valid IPC handle. Additionally, the check-then-act pattern outside the lock creates a race condition where multiple threads might redundantly call cudaIpcGetMemHandle for the same segment. This refactoring fixes both issues by using a map to store and retrieve the serialized handle while holding the lock.

        {
            std::lock_guard<std::mutex> lock(register_mutex_);
            auto it = registered_base_addrs_.find((uint64_t)base_ptr);
            if (it != registered_base_addrs_.end()) {
                desc.shm_path = it->second;
            } else {
                cudaIpcMemHandle_t handle;
                CHECK_CUDA(cudaIpcGetMemHandle(&handle, (void*)base_ptr));
                desc.shm_path = serializeBinaryData(&handle, sizeof(cudaIpcMemHandle_t));
                registered_base_addrs_[(uint64_t)base_ptr] = desc.shm_path;
            }
        }
        desc.addr = (uint64_t)base_ptr;
        desc.length = alloc_size;

} else if (location.type() == "cpu" ||
location.type() == kWildcardLocation) {
if (host_register_)
Expand All @@ -232,11 +267,33 @@ Status NVLinkTransport::addMemoryBuffer(BufferDesc& desc,
}

Status NVLinkTransport::removeMemoryBuffer(BufferDesc& desc) {
    LocationParser location(desc.location);
    if (location.type() == "cuda") {
        // Resolve the allocation base the same way addMemoryBuffer does,
        // so we erase the correct tracking entry even when desc.addr is a
        // sub-allocation inside a larger cudaMalloc segment.
        CUdeviceptr base_ptr = 0;
        size_t alloc_size = 0;
        CUresult cu_err =
            cuMemGetAddressRange(&base_ptr, &alloc_size, (CUdeviceptr)desc.addr);

        // Fall back to the recorded address if the range lookup fails
        // (e.g. the device memory was already freed); removal stays
        // best-effort either way.
        uint64_t key = desc.addr;
        if (cu_err == CUDA_SUCCESS) {
            key = (uint64_t)base_ptr;
        } else {
            LOG(WARNING) << "NVLinkTransport: cuMemGetAddressRange failed for "
                         << "addr 0x" << std::hex << desc.addr << std::dec
                         << " during removal (error " << cu_err
                         << "). Memory may already be freed.";
        }

        {
            // Guard the registration set shared with addMemoryBuffer.
            std::lock_guard<std::mutex> lock(register_mutex_);
            registered_base_addrs_.erase(key);
        }
    } else if (location.type() == "cpu" && host_register_) {
        CHECK_CUDA(cudaHostUnregister((void*)desc.addr));
    }
    // Drop the serialized IPC handle regardless of location type.
    desc.shm_path.clear();
    return Status::OK();
}

Expand Down
Loading