-
Notifications
You must be signed in to change notification settings - Fork 721
[TENT] Fix NVLink IPC address for sub-allocated GPU tensors #1831
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -17,8 +17,10 @@ | |||||
|
|
||||||
| #include <functional> | ||||||
| #include <iostream> | ||||||
| #include <mutex> | ||||||
| #include <queue> | ||||||
| #include <string> | ||||||
| #include <unordered_set> | ||||||
|
|
||||||
| #include <cuda.h> | ||||||
| #include <cuda_runtime.h> | ||||||
|
|
@@ -109,6 +111,9 @@ class NVLinkTransport : public Transport { | |||||
| std::string machine_id_; | ||||||
| uint64_t async_memcpy_threshold_; | ||||||
| bool host_register_; | ||||||
|
|
||||||
| std::mutex register_mutex_; | ||||||
| std::unordered_set<uint64_t> registered_base_addrs_; | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Change the tracking set to a map to store the serialized IPC handle associated with each base address. This is necessary to correctly populate the `shm_path` field (the serialized IPC handle) when a cache hit occurs for an already-registered base address.
Suggested change
|
||||||
| }; | ||||||
| } // namespace tent | ||||||
| } // namespace mooncake | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -215,10 +215,45 @@ Status NVLinkTransport::addMemoryBuffer(BufferDesc& desc, | |
| // If the memory region is allocated using cuMemAlloc, | ||
| // we cannot use cudaIpcGetMemHandle, so skip it | ||
| if (options.type == MNNVL) return Status::OK(); | ||
|
|
||
| // Resolve the true cudaMalloc base address. Caching allocators | ||
| // (e.g. PyTorch) sub-allocate tensors within larger cudaMalloc | ||
| // segments. cudaIpcGetMemHandle returns a handle for the whole | ||
| // segment, so we need to register at segment granularity. | ||
| CUdeviceptr base_ptr = 0; | ||
| size_t alloc_size = 0; | ||
| CUresult cu_err = | ||
| cuMemGetAddressRange(&base_ptr, &alloc_size, (CUdeviceptr)desc.addr); | ||
| if (cu_err != CUDA_SUCCESS) { | ||
| LOG(ERROR) << "NVLinkTransport: cuMemGetAddressRange failed for " | ||
| << "addr 0x" << std::hex << desc.addr << std::dec | ||
| << " (error " << cu_err << ")"; | ||
| return Status::InternalError( | ||
| "cuMemGetAddressRange failed" LOC_MARK); | ||
| } | ||
|
|
||
| { | ||
| std::lock_guard<std::mutex> lock(register_mutex_); | ||
| if (registered_base_addrs_.count((uint64_t)base_ptr)) { | ||
| // Already registered this cudaMalloc block, just tag transport | ||
| desc.addr = (uint64_t)base_ptr; | ||
| desc.length = alloc_size; | ||
| desc.transports.push_back(TransportType::NVLINK); | ||
| return Status::OK(); | ||
| } | ||
| } | ||
|
|
||
| cudaIpcMemHandle_t handle; | ||
| CHECK_CUDA(cudaIpcGetMemHandle(&handle, (void*)desc.addr)); | ||
| CHECK_CUDA(cudaIpcGetMemHandle(&handle, (void*)base_ptr)); | ||
| desc.addr = (uint64_t)base_ptr; | ||
| desc.length = alloc_size; | ||
| desc.shm_path = | ||
| serializeBinaryData(&handle, sizeof(cudaIpcMemHandle_t)); | ||
|
|
||
| { | ||
| std::lock_guard<std::mutex> lock(register_mutex_); | ||
| registered_base_addrs_.insert((uint64_t)base_ptr); | ||
| } | ||
|
Comment on lines
+235
to
+256
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The current implementation has a critical bug: when a base address is already registered (cache hit), it fails to populate `desc.shm_path` with the serialized IPC handle, so remote peers cannot open the memory for that sub-allocation. Suggested change:
{
std::lock_guard<std::mutex> lock(register_mutex_);
auto it = registered_base_addrs_.find((uint64_t)base_ptr);
if (it != registered_base_addrs_.end()) {
desc.shm_path = it->second;
} else {
cudaIpcMemHandle_t handle;
CHECK_CUDA(cudaIpcGetMemHandle(&handle, (void*)base_ptr));
desc.shm_path = serializeBinaryData(&handle, sizeof(cudaIpcMemHandle_t));
registered_base_addrs_[(uint64_t)base_ptr] = desc.shm_path;
}
}
desc.addr = (uint64_t)base_ptr;
desc.length = alloc_size; |
||
| } else if (location.type() == "cpu" || | ||
| location.type() == kWildcardLocation) { | ||
| if (host_register_) | ||
|
|
@@ -232,11 +267,33 @@ Status NVLinkTransport::addMemoryBuffer(BufferDesc& desc, | |
| } | ||
|
|
||
| Status NVLinkTransport::removeMemoryBuffer(BufferDesc& desc) { | ||
| desc.shm_path.clear(); | ||
| LocationParser location(desc.location); | ||
| if (location.type() == "cpu" && host_register_) { | ||
| if (location.type() == "cuda") { | ||
| // Resolve base the same way we did in addMemoryBuffer, so we | ||
| // remove the right entry even for sub-allocated addresses. | ||
| CUdeviceptr base_ptr = 0; | ||
| size_t alloc_size = 0; | ||
| CUresult cu_err = | ||
| cuMemGetAddressRange(&base_ptr, &alloc_size, (CUdeviceptr)desc.addr); | ||
|
|
||
| uint64_t key = desc.addr; | ||
| if (cu_err == CUDA_SUCCESS) { | ||
| key = (uint64_t)base_ptr; | ||
| } else { | ||
| LOG(WARNING) << "NVLinkTransport: cuMemGetAddressRange failed for " | ||
| << "addr 0x" << std::hex << desc.addr << std::dec | ||
| << " during removal (error " << cu_err | ||
| << "). Memory may already be freed."; | ||
| } | ||
|
|
||
| { | ||
| std::lock_guard<std::mutex> lock(register_mutex_); | ||
| registered_base_addrs_.erase(key); | ||
| } | ||
| } else if (location.type() == "cpu" && host_register_) { | ||
| CHECK_CUDA(cudaHostUnregister((void*)desc.addr)); | ||
| } | ||
| desc.shm_path.clear(); | ||
| return Status::OK(); | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use `std::unordered_map` instead of `std::unordered_set` to store the serialized IPC handles for reuse across multiple sub-allocations within the same memory segment.