diff --git a/mooncake-transfer-engine/tent/include/tent/transport/nvlink/nvlink_transport.h b/mooncake-transfer-engine/tent/include/tent/transport/nvlink/nvlink_transport.h
index 2071f3ace4..b4b8da0ecf 100644
--- a/mooncake-transfer-engine/tent/include/tent/transport/nvlink/nvlink_transport.h
+++ b/mooncake-transfer-engine/tent/include/tent/transport/nvlink/nvlink_transport.h
@@ -17,8 +17,10 @@
 #include
 #include
+#include <mutex>
 #include
 #include
+#include <unordered_set>
 #include
 #include
@@ -109,6 +111,9 @@ class NVLinkTransport : public Transport {
     std::string machine_id_;
     uint64_t async_memcpy_threshold_;
     bool host_register_;
+
+    std::mutex register_mutex_;
+    std::unordered_set<uint64_t> registered_base_addrs_;
 };
 }  // namespace tent
 }  // namespace mooncake
diff --git a/mooncake-transfer-engine/tent/src/transport/nvlink/nvlink_transport.cpp b/mooncake-transfer-engine/tent/src/transport/nvlink/nvlink_transport.cpp
index dc41e0f28d..d87026285e 100644
--- a/mooncake-transfer-engine/tent/src/transport/nvlink/nvlink_transport.cpp
+++ b/mooncake-transfer-engine/tent/src/transport/nvlink/nvlink_transport.cpp
@@ -215,10 +215,45 @@ Status NVLinkTransport::addMemoryBuffer(BufferDesc& desc,
         // If the memory region is allocated using cuMemAlloc,
         // we cannot use cudaIpcGetMemHandle, so skip it
         if (options.type == MNNVL) return Status::OK();
+
+        // Resolve the true cudaMalloc base address. Caching allocators
+        // (e.g. PyTorch) sub-allocate tensors within larger cudaMalloc
+        // segments. cudaIpcGetMemHandle returns a handle for the whole
+        // segment, so we need to register at segment granularity.
+        CUdeviceptr base_ptr = 0;
+        size_t alloc_size = 0;
+        CUresult cu_err = cuMemGetAddressRange(&base_ptr, &alloc_size,
+                                               (CUdeviceptr)desc.addr);
+        if (cu_err != CUDA_SUCCESS) {
+            LOG(ERROR) << "NVLinkTransport: cuMemGetAddressRange failed for "
+                       << "addr 0x" << std::hex << desc.addr << std::dec
+                       << " (error " << cu_err << ")";
+            return Status::InternalError(
+                "cuMemGetAddressRange failed" LOC_MARK);
+        }
+
+        {
+            std::lock_guard<std::mutex> lock(register_mutex_);
+            if (registered_base_addrs_.count((uint64_t)base_ptr)) {
+                // Already registered this cudaMalloc segment; just tag the transport
+                desc.addr = (uint64_t)base_ptr;
+                desc.length = alloc_size;
+                desc.transports.push_back(TransportType::NVLINK);
+                return Status::OK();
+            }
+        }
+
         cudaIpcMemHandle_t handle;
-        CHECK_CUDA(cudaIpcGetMemHandle(&handle, (void*)desc.addr));
+        CHECK_CUDA(cudaIpcGetMemHandle(&handle, (void*)base_ptr));
+        desc.addr = (uint64_t)base_ptr;
+        desc.length = alloc_size;
         desc.shm_path =
             serializeBinaryData(&handle, sizeof(cudaIpcMemHandle_t));
+
+        {
+            std::lock_guard<std::mutex> lock(register_mutex_);
+            registered_base_addrs_.insert((uint64_t)base_ptr);
+        }
     } else if (location.type() == "cpu" ||
                location.type() == kWildcardLocation) {
         if (host_register_)
@@ -232,11 +267,33 @@
 }

 Status NVLinkTransport::removeMemoryBuffer(BufferDesc& desc) {
-    desc.shm_path.clear();
     LocationParser location(desc.location);
-    if (location.type() == "cpu" && host_register_) {
+    if (location.type() == "cuda") {
+        // Resolve the base address the same way addMemoryBuffer did, so
+        // we remove the right entry even for sub-allocated addresses.
+        CUdeviceptr base_ptr = 0;
+        size_t alloc_size = 0;
+        CUresult cu_err = cuMemGetAddressRange(&base_ptr, &alloc_size,
+                                               (CUdeviceptr)desc.addr);
+
+        uint64_t key = desc.addr;
+        if (cu_err == CUDA_SUCCESS) {
+            key = (uint64_t)base_ptr;
+        } else {
+            LOG(WARNING) << "NVLinkTransport: cuMemGetAddressRange failed for "
+                         << "addr 0x" << std::hex << desc.addr << std::dec
+                         << " during removal (error " << cu_err
+                         << "). Memory may already be freed.";
+        }
+
+        {
+            std::lock_guard<std::mutex> lock(register_mutex_);
+            registered_base_addrs_.erase(key);
+        }
+    } else if (location.type() == "cpu" && host_register_) {
         CHECK_CUDA(cudaHostUnregister((void*)desc.addr));
     }
+    desc.shm_path.clear();
     return Status::OK();
 }
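
For context on the driver call the patch leans on: cuMemGetAddressRange maps any device pointer inside an allocation back to that allocation's base address and total size. The standalone sketch below (illustrative only, not part of the patch; names and sizes are made up) reproduces the scenario described in the patch comments: an interior pointer, as a caching allocator such as PyTorch's would hand out, resolves to the enclosing cudaMalloc segment, which is why registration must be keyed on the segment base.

    // Build with: nvcc demo.cpp -lcuda  (error handling abbreviated)
    #include <cuda.h>
    #include <cuda_runtime.h>
    #include <cstdio>
    #include <cstdlib>

    int main() {
        // Allocate a 1 MiB "segment", standing in for a caching
        // allocator's cudaMalloc block. cudaMalloc also initializes
        // the primary context, so driver API calls work afterwards.
        void* seg = nullptr;
        const size_t seg_size = 1 << 20;
        if (cudaMalloc(&seg, seg_size) != cudaSuccess) return EXIT_FAILURE;

        // An interior pointer, as a sub-allocating caching allocator
        // would return for an individual tensor.
        CUdeviceptr interior = (CUdeviceptr)seg + 4096;

        // The driver resolves the interior pointer back to the segment.
        CUdeviceptr base = 0;
        size_t size = 0;
        if (cuMemGetAddressRange(&base, &size, interior) != CUDA_SUCCESS)
            return EXIT_FAILURE;

        // base == seg and size covers the whole segment, so IPC export
        // and registration happen at (base, size) granularity.
        printf("interior=%p -> base=%p size=%zu (segment %p, %zu)\n",
               (void*)interior, (void*)base, size, seg, seg_size);

        cudaFree(seg);
        return 0;
    }

Keying registered_base_addrs_ on the segment base also deduplicates registrations when many tensors share one segment: the first buffer exports the IPC handle for the whole segment, and later buffers in the same segment only tag the transport and return early.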