From 8d56464515ed642e2284b60e6afda7c1b1594a86 Mon Sep 17 00:00:00 2001 From: doujiang24 Date: Sat, 12 Apr 2025 19:16:44 +0800 Subject: [PATCH 1/4] draft: p2phandshake: init Signed-off-by: doujiang24 --- .../transfer_engine/transfer_engine_py.cpp | 37 ++-- .../example/http-metadata-server/go.mod | 6 +- mooncake-transfer-engine/include/common.h | 49 ++++- .../include/transfer_metadata.h | 11 +- .../include/transfer_metadata_plugin.h | 21 ++- .../transport/rdma_transport/rdma_transport.h | 2 +- .../src/transfer_engine.cpp | 4 + .../src/transfer_metadata.cpp | 150 ++++++++++++--- .../src/transfer_metadata_plugin.cpp | 171 ++++++++++++++++-- .../rdma_transport/rdma_transport.cpp | 37 ++-- .../transport/tcp_transport/tcp_transport.cpp | 14 +- 11 files changed, 401 insertions(+), 101 deletions(-) diff --git a/mooncake-integration/transfer_engine/transfer_engine_py.cpp b/mooncake-integration/transfer_engine/transfer_engine_py.cpp index f3a6bf43c..631499759 100644 --- a/mooncake-integration/transfer_engine/transfer_engine_py.cpp +++ b/mooncake-integration/transfer_engine/transfer_engine_py.cpp @@ -71,17 +71,19 @@ std::pair parseConnectionString( } int TransferEnginePy::initialize(const char *local_hostname, - const char *metadata_server, const char *protocol, - const char *device_name) { + const char *metadata_server, + const char *protocol, + const char *device_name) { auto conn_string = parseConnectionString(metadata_server); return initializeExt(local_hostname, conn_string.second.c_str(), protocol, device_name, conn_string.first.c_str()); } int TransferEnginePy::initializeExt(const char *local_hostname, - const char *metadata_server, - const char *protocol, const char *device_name, - const char *metadata_type) { + const char *metadata_server, + const char *protocol, + const char *device_name, + const char *metadata_type) { std::string conn_string = metadata_server; if (conn_string.find("://") == std::string::npos) conn_string = @@ -104,9 +106,10 @@ int TransferEnginePy::initializeExt(const char *local_hostname, xport_ = nullptr; if (strcmp(protocol, "rdma") == 0) { auto device_names = formatDeviceNames(device_name); - std::string nic_priority_matrix = - "{\"cpu:0\": [[" + device_names + "], []]," - "\"cuda:0\": [[" + device_names + "], []]}"; + std::string nic_priority_matrix = "{\"cpu:0\": [[" + device_names + + "], []]," + "\"cuda:0\": [[" + + device_names + "], []]}"; void **args = (void **)malloc(2 * sizeof(void *)); args[0] = (void *)nic_priority_matrix.c_str(); args[1] = nullptr; @@ -193,8 +196,10 @@ int TransferEnginePy::freeManagedBuffer(uintptr_t buffer_addr, size_t length) { return 0; } -int TransferEnginePy::transferSyncWrite(const char *target_hostname, uintptr_t buffer, - uintptr_t peer_buffer_address, size_t length) { +int TransferEnginePy::transferSyncWrite(const char *target_hostname, + uintptr_t buffer, + uintptr_t peer_buffer_address, + size_t length) { Transport::SegmentHandle handle; if (handle_map_.count(target_hostname)) { handle = handle_map_[target_hostname]; @@ -229,8 +234,10 @@ int TransferEnginePy::transferSyncWrite(const char *target_hostname, uintptr_t b } } -int TransferEnginePy::transferSyncRead(const char *target_hostname, uintptr_t buffer, - uintptr_t peer_buffer_address, size_t length) { +int TransferEnginePy::transferSyncRead(const char *target_hostname, + uintptr_t buffer, + uintptr_t peer_buffer_address, + size_t length) { Transport::SegmentHandle handle; if (handle_map_.count(target_hostname)) { handle = handle_map_[target_hostname]; @@ -348,13 +355,15 @@ PYBIND11_MODULE(engine, m) { .def(py::init<>()) .def("initialize", &TransferEnginePy::initialize) .def("initialize_ext", &TransferEnginePy::initializeExt) - .def("allocate_managed_buffer", &TransferEnginePy::allocateManagedBuffer) + .def("allocate_managed_buffer", + &TransferEnginePy::allocateManagedBuffer) .def("free_managed_buffer", &TransferEnginePy::freeManagedBuffer) .def("transfer_sync_write", &TransferEnginePy::transferSyncWrite) .def("transfer_sync_read", &TransferEnginePy::transferSyncRead) .def("transfer_sync", &TransferEnginePy::transferSync) .def("write_bytes_to_buffer", &TransferEnginePy::writeBytesToBuffer) - .def("read_bytes_from_buffer", &TransferEnginePy::readBytesFromBuffer) + .def("read_bytes_from_buffer", + &TransferEnginePy::readBytesFromBuffer) .def("register_memory", &TransferEnginePy::registerMemory) .def("unregister_memory", &TransferEnginePy::unregisterMemory) .def("get_first_buffer_address", diff --git a/mooncake-transfer-engine/example/http-metadata-server/go.mod b/mooncake-transfer-engine/example/http-metadata-server/go.mod index 05e2cd401..6cf7ad68c 100644 --- a/mooncake-transfer-engine/example/http-metadata-server/go.mod +++ b/mooncake-transfer-engine/example/http-metadata-server/go.mod @@ -1,6 +1,8 @@ module github.com/kvcache-ai/Mooncake/mooncake-transfer-engine/example/http-metadata-server -go 1.22.9 +go 1.23.0 + +toolchain go1.23.8 require github.com/gin-gonic/gin v1.10.0 @@ -25,7 +27,7 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect golang.org/x/arch v0.8.0 // indirect - golang.org/x/crypto v0.31.0 // indirect + golang.org/x/crypto v0.35.0 // indirect golang.org/x/net v0.36.0 // indirect golang.org/x/sys v0.30.0 // indirect golang.org/x/text v0.22.0 // indirect diff --git a/mooncake-transfer-engine/include/common.h b/mooncake-transfer-engine/include/common.h index b717f652e..30a4f9adc 100644 --- a/mooncake-transfer-engine/include/common.h +++ b/mooncake-transfer-engine/include/common.h @@ -41,6 +41,11 @@ namespace mooncake { const static int LOCAL_SEGMENT_ID = 0; +enum class HandShakeRequestType { + Connection = 0, + Metadata = 1, +}; + static inline int bindToSocket(int socket_id) { if (unlikely(numa_available() < 0)) { LOG(ERROR) << "The platform does not support NUMA"; @@ -140,27 +145,53 @@ static inline ssize_t readFully(int fd, void *buf, size_t len) { return len; } -static inline int writeString(int fd, const std::string &str) { - uint64_t length = str.size(); +static inline int writeString(int fd, const HandShakeRequestType type, + const std::string &str) { + uint8_t byte = static_cast(type); + LOG(INFO) << "writeString: type " << (int)byte << ", str(" << str.size() + << "): " << str; + uint64_t length = str.size() + sizeof(byte); if (writeFully(fd, &length, sizeof(length)) != (ssize_t)sizeof(length)) return ERR_SOCKET; + if (writeFully(fd, &byte, sizeof(byte)) != (ssize_t)sizeof(byte)) + return ERR_SOCKET; if (writeFully(fd, str.data(), length) != (ssize_t)length) return ERR_SOCKET; return 0; } -static inline std::string readString(int fd) { +static inline std::pair readString(int fd) { + HandShakeRequestType type = HandShakeRequestType::Connection; + const static size_t kMaxLength = 1ull << 20; uint64_t length = 0; - if (readFully(fd, &length, sizeof(length)) != (ssize_t)sizeof(length)) - return ""; - if (length > kMaxLength) return ""; + ssize_t n = readFully(fd, &length, sizeof(length)); + if (n != (ssize_t)sizeof(length)) { + LOG(ERROR) << "Failed to read length, got: " << n; + return {type, ""}; + } + + if (length > kMaxLength) { + LOG(ERROR) << "Read too large length from socket: " << length; + return {type, ""}; + } + std::string str; std::vector buffer(length); - if (readFully(fd, buffer.data(), length) != (ssize_t)length) return ""; + n = readFully(fd, buffer.data(), length); + if (n != (ssize_t)length) { + LOG(ERROR) << "Failed to read string, got: " << n; + return {type, ""}; + } + + if (buffer[0] <= static_cast(HandShakeRequestType::Metadata)) { + type = static_cast(buffer[0]); + str.assign(buffer.data() + 1, length - 1); + } else { + str.assign(buffer.data(), length); + } - str.assign(buffer.data(), length); - return str; + return {type, str}; } const static std::string NIC_PATH_DELIM = "@"; diff --git a/mooncake-transfer-engine/include/transfer_metadata.h b/mooncake-transfer-engine/include/transfer_metadata.h index 1ddb4c897..b267ac848 100644 --- a/mooncake-transfer-engine/include/transfer_metadata.h +++ b/mooncake-transfer-engine/include/transfer_metadata.h @@ -34,6 +34,8 @@ namespace mooncake { struct MetadataStoragePlugin; struct HandShakePlugin; +#define P2PHANDSHAKE "P2PHANDSHAKE" + class TransferMetadata { public: struct DeviceDesc { @@ -73,7 +75,7 @@ class TransferMetadata { struct RpcMetaDesc { std::string ip_or_host_name; uint16_t rpc_port; - int sockfd; // local cache + int sockfd; // local cache }; struct HandShakeDesc { @@ -134,6 +136,13 @@ class TransferMetadata { HandShakeDesc &peer_desc); private: + int encodeSegmentDesc(const SegmentDesc &desc, Json::Value &segmentJSON); + std::shared_ptr decodeSegmentDesc( + Json::Value &segmentJSON, const std::string &segment_name); + int receivePeerMetadata(const Json::Value &peer_json, + Json::Value &local_json); + + bool p2p_handshake_mode_{false}; // local cache RWSpinlock segment_lock_; std::unordered_map> diff --git a/mooncake-transfer-engine/include/transfer_metadata_plugin.h b/mooncake-transfer-engine/include/transfer_metadata_plugin.h index 022e4792f..73d250fc1 100644 --- a/mooncake-transfer-engine/include/transfer_metadata_plugin.h +++ b/mooncake-transfer-engine/include/transfer_metadata_plugin.h @@ -40,16 +40,31 @@ struct HandShakePlugin { // When accept a new connection, this function will be called. // The first param represents peer endpoint's attributes, while // the second param represents local endpoint's attributes - using OnReceiveCallBack = + using OnConnectionCallBack = std::function; - virtual int startDaemon(OnReceiveCallBack on_recv_callback, - uint16_t listen_port, int sockfd) = 0; + // When accept a new metadata request. + using OnMetadataCallBack = + std::function; + + virtual int startDaemon(uint16_t listen_port, int sockfd) = 0; // Connect to peer endpoint, and wait for receiving // peer endpoint's attributes virtual int send(std::string ip_or_host_name, uint16_t rpc_port, const Json::Value &local, Json::Value &peer) = 0; + + // Exchange metadata with remote peer. + virtual int exchangeMetadata(std::string ip_or_host_name, uint16_t rpc_port, + const Json::Value &local_metadata, + Json::Value &peer_metadata) = 0; + + // Register callback function for receiving a new connection. + virtual void registerOnConnectionCallBack( + OnConnectionCallBack callback) = 0; + + // Register callback function for receiving metadata exchange request. + virtual void registerOnMetadataCallBack(OnMetadataCallBack callback) = 0; }; std::vector findLocalIpAddresses(); diff --git a/mooncake-transfer-engine/include/transport/rdma_transport/rdma_transport.h b/mooncake-transfer-engine/include/transport/rdma_transport/rdma_transport.h index 4f2bccb37..695629bb2 100644 --- a/mooncake-transfer-engine/include/transport/rdma_transport/rdma_transport.h +++ b/mooncake-transfer-engine/include/transport/rdma_transport/rdma_transport.h @@ -74,7 +74,7 @@ class RdmaTransport : public Transport { // TRANSFER Status submitTransfer(BatchID batch_id, - const std::vector &entries) override; + const std::vector &entries) override; Status submitTransferTask( const std::vector &request_list, diff --git a/mooncake-transfer-engine/src/transfer_engine.cpp b/mooncake-transfer-engine/src/transfer_engine.cpp index 6e70932a5..6751f1cb6 100644 --- a/mooncake-transfer-engine/src/transfer_engine.cpp +++ b/mooncake-transfer-engine/src/transfer_engine.cpp @@ -125,6 +125,10 @@ Transport *TransferEngine::installTransport(const std::string &proto, int TransferEngine::uninstallTransport(const std::string &proto) { return 0; } +// port: env +// ip + +// ip:port Transport::SegmentHandle TransferEngine::openSegment( const std::string &segment_name) { if (segment_name.empty()) return ERR_INVALID_ARGUMENT; diff --git a/mooncake-transfer-engine/src/transfer_metadata.cpp b/mooncake-transfer-engine/src/transfer_metadata.cpp index 6b3355aca..e8c670be2 100644 --- a/mooncake-transfer-engine/src/transfer_metadata.cpp +++ b/mooncake-transfer-engine/src/transfer_metadata.cpp @@ -68,20 +68,29 @@ struct TransferHandshakeUtil { }; TransferMetadata::TransferMetadata(const std::string &conn_string) { + next_segment_id_.store(1); handshake_plugin_ = HandShakePlugin::Create(conn_string); + if (!handshake_plugin_) { + LOG(ERROR) + << "Unable to create metadata handshake plugin with conn string: " + << conn_string; + } + if (conn_string == P2PHANDSHAKE) { + p2p_handshake_mode_ = true; + return; + } storage_plugin_ = MetadataStoragePlugin::Create(conn_string); - if (!handshake_plugin_ || !storage_plugin_) { - LOG(ERROR) << "Unable to create metadata plugins with conn string " - << conn_string; + if (!storage_plugin_) { + LOG(ERROR) + << "Unable to create metadata storage plugin with conn string " + << conn_string; } - next_segment_id_.store(1); } TransferMetadata::~TransferMetadata() { handshake_plugin_.reset(); } -int TransferMetadata::updateSegmentDesc(const std::string &segment_name, - const SegmentDesc &desc) { - Json::Value segmentJSON; +int TransferMetadata::encodeSegmentDesc(const SegmentDesc &desc, + Json::Value &segmentJSON) { segmentJSON["name"] = desc.name; segmentJSON["protocol"] = desc.protocol; @@ -127,6 +136,20 @@ int TransferMetadata::updateSegmentDesc(const std::string &segment_name, << desc.name << " protocol " << desc.protocol; return ERR_METADATA; } + return 0; +} + +int TransferMetadata::updateSegmentDesc(const std::string &segment_name, + const SegmentDesc &desc) { + if (p2p_handshake_mode_) { + return 0; + } + + Json::Value segmentJSON; + int ret = encodeSegmentDesc(desc, segmentJSON); + if (ret) { + return ret; + } if (!storage_plugin_->set(getFullMetadataKey(segment_name), segmentJSON)) { LOG(ERROR) << "Failed to register segment descriptor, name " @@ -138,6 +161,9 @@ int TransferMetadata::updateSegmentDesc(const std::string &segment_name, } int TransferMetadata::removeSegmentDesc(const std::string &segment_name) { + if (p2p_handshake_mode_) { + return 0; + } if (!storage_plugin_->remove(getFullMetadataKey(segment_name))) { LOG(ERROR) << "Failed to unregister segment descriptor, name " << segment_name; @@ -146,15 +172,9 @@ int TransferMetadata::removeSegmentDesc(const std::string &segment_name) { return 0; } -std::shared_ptr TransferMetadata::getSegmentDesc( - const std::string &segment_name) { - Json::Value segmentJSON; - if (!storage_plugin_->get(getFullMetadataKey(segment_name), segmentJSON)) { - LOG(WARNING) << "Failed to retrieve segment descriptor, name " - << segment_name; - return nullptr; - } - +std::shared_ptr +TransferMetadata::decodeSegmentDesc(Json::Value &segmentJSON, + const std::string &segment_name) { auto desc = std::make_shared(); desc->name = segmentJSON["name"].asString(); desc->protocol = segmentJSON["protocol"].asString(); @@ -227,10 +247,47 @@ std::shared_ptr TransferMetadata::getSegmentDesc( << " protocol " << desc->protocol; return nullptr; } - return desc; } +int TransferMetadata::receivePeerMetadata(const Json::Value &peer_json, + Json::Value &local_json) { + // auto peer_desc = decodeSegmentDesc(peer_json, + // peer_json["name"].asString()); + auto local_desc = segment_id_to_desc_map_[LOCAL_SEGMENT_ID]; + int ret = encodeSegmentDesc(*local_desc.get(), local_json); + return ret; +} + +std::shared_ptr TransferMetadata::getSegmentDesc( + const std::string &segment_name) { + Json::Value peer_json; + + if (p2p_handshake_mode_) { + auto [ip, port] = parseHostNameWithPort(segment_name); + Json::Value local_json; + auto desc = segment_id_to_desc_map_[LOCAL_SEGMENT_ID]; + int ret = encodeSegmentDesc(*desc.get(), local_json); + if (ret) { + return nullptr; + } + ret = handshake_plugin_->exchangeMetadata(ip, port, local_json, + peer_json); + if (ret) { + return nullptr; + } + } else { + if (!storage_plugin_->get(getFullMetadataKey(segment_name), + peer_json)) { + LOG(WARNING) << "Failed to retrieve segment descriptor, name " + << segment_name; + return nullptr; + } + } + + return decodeSegmentDesc(peer_json, segment_name); +} + int TransferMetadata::syncSegmentCache(const std::string &segment_name) { RWSpinlock::WriteGuard guard(segment_lock_); for (auto &entry : segment_id_to_desc_map_) { @@ -275,7 +332,8 @@ TransferMetadata::getSegmentDescByName(const std::string &segment_name, std::shared_ptr TransferMetadata::getSegmentDescByID(SegmentID segment_id, bool force_update) { - if (segment_id != LOCAL_SEGMENT_ID && (!globalConfig().metacache || force_update)) { + if (segment_id != LOCAL_SEGMENT_ID && + (!globalConfig().metacache || force_update)) { RWSpinlock::WriteGuard guard(segment_lock_); if (!segment_id_to_desc_map_.count(segment_id)) return nullptr; auto segment_desc = @@ -365,6 +423,21 @@ int TransferMetadata::removeLocalMemoryBuffer(void *addr, int TransferMetadata::addRpcMetaEntry(const std::string &server_name, RpcMetaDesc &desc) { + local_rpc_meta_ = desc; + + if (p2p_handshake_mode_) { + int rc = handshake_plugin_->startDaemon(desc.rpc_port, desc.sockfd); + if (rc != 0) { + return rc; + } + handshake_plugin_->registerOnMetadataCallBack( + [this](const Json::Value &peer, Json::Value &local) -> int { + return receivePeerMetadata(peer, local); + }); + + return 0; + } + Json::Value rpcMetaJSON; rpcMetaJSON["ip_or_host_name"] = desc.ip_or_host_name; rpcMetaJSON["rpc_port"] = static_cast(desc.rpc_port); @@ -372,11 +445,13 @@ int TransferMetadata::addRpcMetaEntry(const std::string &server_name, LOG(ERROR) << "Failed to set location of " << server_name; return ERR_METADATA; } - local_rpc_meta_ = desc; return 0; } int TransferMetadata::removeRpcMetaEntry(const std::string &server_name) { + if (p2p_handshake_mode_) { + return 0; + } if (!storage_plugin_->remove(kRpcMetaPrefix + server_name)) { LOG(ERROR) << "Failed to remove location of " << server_name; return ERR_METADATA; @@ -394,20 +469,31 @@ int TransferMetadata::getRpcMetaEntry(const std::string &server_name, } } RWSpinlock::WriteGuard guard(rpc_meta_lock_); - Json::Value rpcMetaJSON; - if (!storage_plugin_->get(kRpcMetaPrefix + server_name, rpcMetaJSON)) { - LOG(ERROR) << "Failed to find location of " << server_name; - return ERR_METADATA; + if (p2p_handshake_mode_) { + auto [ip, port] = parseHostNameWithPort(server_name); + desc.ip_or_host_name = ip; + desc.rpc_port = port; + } else { + Json::Value rpcMetaJSON; + if (!storage_plugin_->get(kRpcMetaPrefix + server_name, rpcMetaJSON)) { + LOG(ERROR) << "Failed to find location of " << server_name; + return ERR_METADATA; + } + desc.ip_or_host_name = rpcMetaJSON["ip_or_host_name"].asString(); + desc.rpc_port = (uint16_t)rpcMetaJSON["rpc_port"].asUInt(); } - desc.ip_or_host_name = rpcMetaJSON["ip_or_host_name"].asString(); - desc.rpc_port = (uint16_t)rpcMetaJSON["rpc_port"].asUInt(); rpc_meta_map_[server_name] = desc; return 0; } int TransferMetadata::startHandshakeDaemon( OnReceiveHandShake on_receive_handshake, uint16_t listen_port, int sockfd) { - return handshake_plugin_->startDaemon( + int rc = handshake_plugin_->startDaemon(listen_port, sockfd); + if (rc != 0) { + return rc; + } + + handshake_plugin_->registerOnConnectionCallBack( [on_receive_handshake](const Json::Value &peer, Json::Value &local) -> int { HandShakeDesc local_desc, peer_desc; @@ -416,16 +502,20 @@ int TransferMetadata::startHandshakeDaemon( if (ret) return ret; local = TransferHandshakeUtil::encode(local_desc); return 0; - }, - listen_port, sockfd); + }); + + return 0; } int TransferMetadata::sendHandshake(const std::string &peer_server_name, const HandShakeDesc &local_desc, HandShakeDesc &peer_desc) { RpcMetaDesc peer_location; - if (getRpcMetaEntry(peer_server_name, peer_location)) { - return ERR_METADATA; + if (p2p_handshake_mode_) { + } else { + if (getRpcMetaEntry(peer_server_name, peer_location)) { + return ERR_METADATA; + } } auto local = TransferHandshakeUtil::encode(local_desc); Json::Value peer; diff --git a/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp b/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp index 1baf8139f..cdf73d08c 100644 --- a/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp +++ b/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp @@ -448,7 +448,8 @@ std::shared_ptr MetadataStoragePlugin::Create( #endif // USE_HTTP LOG(FATAL) << "Unable to find metadata storage plugin " - << parsed_conn_string.first; + << parsed_conn_string.first + << " with conn string: " << conn_string; return nullptr; } @@ -491,8 +492,20 @@ struct SocketHandShakePlugin : public HandShakePlugin { } } - virtual int startDaemon(OnReceiveCallBack on_recv_callback, - uint16_t listen_port, int sockfd) { + virtual void registerOnConnectionCallBack(OnConnectionCallBack callback) { + on_connection_callback_ = callback; + } + + virtual void registerOnMetadataCallBack(OnMetadataCallBack callback) { + on_metadata_callback_ = callback; + } + + virtual int startDaemon(uint16_t listen_port, int sockfd) { + if (listener_running_) { + LOG(INFO) << "SocketHandShakePlugin: listener already running"; + return 0; + } + sockaddr_in bind_address; int on = 1; memset(&bind_address, 0, sizeof(sockaddr_in)); @@ -543,7 +556,7 @@ struct SocketHandShakePlugin : public HandShakePlugin { } listener_running_ = true; - listener_ = std::thread([this, on_recv_callback]() { + listener_ = std::thread([this]() { while (listener_running_) { sockaddr_in addr; socklen_t addr_len = sizeof(sockaddr_in); @@ -580,24 +593,51 @@ struct SocketHandShakePlugin : public HandShakePlugin { Json::Value local, peer; Json::Reader reader; - if (!reader.parse(readString(conn_fd), peer)) { + + auto [type, json_str] = readString(conn_fd); + if (!reader.parse(json_str, peer)) { LOG(ERROR) << "SocketHandShakePlugin: failed to receive " - "handshake message: " - "malformed json format, check tcp connection"; + "handshake message, " + "malformed json format:" + << reader.getFormattedErrorMessages() + << ", json string length: " << json_str.size() + << ", json string content: " << json_str; + close(conn_fd); + continue; + } + + if (type == HandShakeRequestType::Connection) { + on_connection_callback_(peer, local); + } else if (type == HandShakeRequestType::Metadata) { + on_metadata_callback_(peer, local); + } else { + LOG(ERROR) << "SocketHandShakePlugin: unexpected handshake " + "message type"; close(conn_fd); continue; } - on_recv_callback(peer, local); - int ret = writeString(conn_fd, Json::FastWriter{}.write(local)); + int ret = + writeString(conn_fd, type, Json::FastWriter{}.write(local)); + LOG(INFO) << "writeString return: " << ret; if (ret) { LOG(ERROR) << "SocketHandShakePlugin: failed to send " - "handshake message: " + "message: " "malformed json format, check tcp connection"; close(conn_fd); continue; } + ret = shutdown(conn_fd, SHUT_WR); + if (ret) { + PLOG(ERROR) << "SocketHandShakePlugin: shutdown() failed, " + "connection may be incomplete"; + close(conn_fd); + continue; + } + + // TODO: wait for the peer to close the connection + close(conn_fd); } return; @@ -641,15 +681,13 @@ struct SocketHandShakePlugin : public HandShakePlugin { return ret; } - int doSend(struct addrinfo *addr, const Json::Value &local, - Json::Value &peer) { + int doConnect(struct addrinfo *addr, int &conn_fd) { if (globalConfig().verbose) LOG(INFO) << "SocketHandShakePlugin: connecting " << getNetworkAddress(addr->ai_addr); int on = 1; - int conn_fd = - socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); + conn_fd = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); if (conn_fd == -1) { PLOG(ERROR) << "SocketHandShakePlugin: socket()"; return ERR_SOCKET; @@ -677,7 +715,19 @@ struct SocketHandShakePlugin : public HandShakePlugin { return ERR_SOCKET; } - int ret = writeString(conn_fd, Json::FastWriter{}.write(local)); + return 0; + } + + int doSend(struct addrinfo *addr, const Json::Value &local, + Json::Value &peer) { + int conn_fd = -1; + int ret = doConnect(addr, conn_fd); + if (ret) { + return ret; + } + + ret = writeString(conn_fd, HandShakeRequestType::Connection, + Json::FastWriter{}.write(local)); if (ret) { LOG(ERROR) << "SocketHandShakePlugin: failed to send handshake message: " @@ -687,7 +737,15 @@ struct SocketHandShakePlugin : public HandShakePlugin { } Json::Reader reader; - if (!reader.parse(readString(conn_fd), peer)) { + auto [type, json_str] = readString(conn_fd); + if (type != HandShakeRequestType::Connection) { + LOG(ERROR) + << "SocketHandShakePlugin: unexpected handshake message type"; + close(conn_fd); + return ERR_SOCKET; + } + + if (!reader.parse(json_str, peer)) { LOG(ERROR) << "SocketHandShakePlugin: failed to receive handshake " "message: " "malformed json format, check tcp connection"; @@ -699,9 +757,90 @@ struct SocketHandShakePlugin : public HandShakePlugin { return 0; } + virtual int exchangeMetadata(std::string ip_or_host_name, uint16_t rpc_port, + const Json::Value &local_metadata, + Json::Value &peer_metadata) { + struct addrinfo hints; + struct addrinfo *result, *rp; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET; + hints.ai_socktype = SOCK_STREAM; + + char service[16]; + sprintf(service, "%u", rpc_port); + if (getaddrinfo(ip_or_host_name.c_str(), service, &hints, &result)) { + PLOG(ERROR) + << "SocketHandShakePlugin: failed to get IP address of peer " + "server " + << ip_or_host_name << ":" << rpc_port + << ", check DNS and /etc/hosts, or use IPv4 address instead"; + return ERR_DNS; + } + + int ret = 0; + for (rp = result; rp; rp = rp->ai_next) { + ret = doSendMetadta(rp, local_metadata, peer_metadata); + if (ret == 0) { + freeaddrinfo(result); + return 0; + } + if (ret == ERR_MALFORMED_JSON) { + return ret; + } + } + + freeaddrinfo(result); + return ret; + } + + int doSendMetadta(struct addrinfo *addr, const Json::Value &local_metadata, + Json::Value &peer_metadata) { + int conn_fd = -1; + int ret = doConnect(addr, conn_fd); + if (ret) { + return ret; + } + + ret = writeString(conn_fd, HandShakeRequestType::Metadata, + Json::FastWriter{}.write(local_metadata)); + if (ret) { + LOG(ERROR) + << "SocketHandShakePlugin: failed to send metadata message: " + "malformed json format, check tcp connection"; + close(conn_fd); + return ret; + } + + Json::Reader reader; + auto [type, json_str] = readString(conn_fd); + if (type != HandShakeRequestType::Metadata) { + LOG(ERROR) + << "SocketHandShakePlugin: unexpected handshake message type"; + close(conn_fd); + return ERR_SOCKET; + } + + LOG(INFO) << "SocketHandShakePlugin: received metadata message: " + << json_str; + + if (!reader.parse(json_str, peer_metadata)) { + LOG(ERROR) << "SocketHandShakePlugin: failed to receive metadata " + "message, malformed json format: " + << reader.getFormattedErrorMessages(); + close(conn_fd); + return ERR_MALFORMED_JSON; + } + + close(conn_fd); + return 0; + } + std::atomic listener_running_; std::thread listener_; int listen_fd_; + + OnConnectionCallBack on_connection_callback_; + OnMetadataCallBack on_metadata_callback_; }; std::shared_ptr HandShakePlugin::Create( diff --git a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp index ddb4dc0d3..a4da14f3e 100644 --- a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp +++ b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp @@ -193,8 +193,8 @@ int RdmaTransport::unregisterLocalMemoryBatch( return metadata_->updateLocalSegmentDesc(); } -Status RdmaTransport::submitTransfer(BatchID batch_id, - const std::vector &entries) { +Status RdmaTransport::submitTransfer( + BatchID batch_id, const std::vector &entries) { auto &batch_desc = *((BatchDesc *)(batch_id)); if (batch_desc.task_list.size() + entries.size() > batch_desc.batch_size) { LOG(ERROR) << "RdmaTransport: Exceed the limitation of current batch's " @@ -241,22 +241,22 @@ Status RdmaTransport::submitTransfer(BatchID batch_id, local_segment_desc->buffers[buffer_id].lkey[device_id]; slices_to_post[context].push_back(slice); task.total_bytes += slice->length; - __sync_fetch_and_add(&task.slice_count, 1);; + __sync_fetch_and_add(&task.slice_count, 1); + ; break; } if (device_id < 0) { auto source_addr = slice->source_addr; delete slice; for (auto &entry : slices_to_post) - for (auto s : entry.second) - delete s; + for (auto s : entry.second) delete s; LOG(ERROR) << "RdmaTransport: Address not registered by any device(s) " << source_addr; return Status::AddressNotRegistered( - "RdmaTransport: not registered by any device(s), address: " - + std::to_string( - reinterpret_cast(source_addr))); + "RdmaTransport: not registered by any device(s), " + "address: " + + std::to_string(reinterpret_cast(source_addr))); } } } @@ -303,22 +303,22 @@ Status RdmaTransport::submitTransferTask( slices_to_post[context].push_back(slice); task.total_bytes += slice->length; // task.slices.push_back(slice); - __sync_fetch_and_add(&task.slice_count, 1);; + __sync_fetch_and_add(&task.slice_count, 1); + ; break; } if (device_id < 0) { auto source_addr = slice->source_addr; delete slice; for (auto &entry : slices_to_post) - for (auto s : entry.second) - delete s; + for (auto s : entry.second) delete s; LOG(ERROR) << "RdmaTransport: Address not registered by any device(s) " << source_addr; return Status::AddressNotRegistered( - "RdmaTransport: not registered by any device(s), address: " - + std::to_string( - reinterpret_cast(source_addr))); + "RdmaTransport: not registered by any device(s), " + "address: " + + std::to_string(reinterpret_cast(source_addr))); } } } @@ -337,8 +337,7 @@ Status RdmaTransport::getTransferStatus(BatchID batch_id, status[task_id].transferred_bytes = task.transferred_bytes; uint64_t success_slice_count = task.success_slice_count; uint64_t failed_slice_count = task.failed_slice_count; - if (success_slice_count + failed_slice_count == - task.slice_count) { + if (success_slice_count + failed_slice_count == task.slice_count) { if (failed_slice_count) status[task_id].s = TransferStatusEnum::FAILED; else @@ -364,8 +363,7 @@ Status RdmaTransport::getTransferStatus(BatchID batch_id, size_t task_id, status.transferred_bytes = task.transferred_bytes; uint64_t success_slice_count = task.success_slice_count; uint64_t failed_slice_count = task.failed_slice_count; - if (success_slice_count + failed_slice_count == - task.slice_count) { + if (success_slice_count + failed_slice_count == task.slice_count) { if (failed_slice_count) status.s = TransferStatusEnum::FAILED; else @@ -434,8 +432,7 @@ int RdmaTransport::startHandshakeDaemon(std::string &local_server_name) { return metadata_->startHandshakeDaemon( std::bind(&RdmaTransport::onSetupRdmaConnections, this, std::placeholders::_1, std::placeholders::_2), - metadata_->localRpcMeta().rpc_port, - metadata_->localRpcMeta().sockfd); + metadata_->localRpcMeta().rpc_port, metadata_->localRpcMeta().sockfd); } // According to the request desc, offset and length information, find proper diff --git a/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp b/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp index ce649c34e..9d003a31f 100644 --- a/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp +++ b/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp @@ -221,9 +221,11 @@ int TcpTransport::install(std::string &local_server_name, } if (meta->localRpcMeta().sockfd) { - close(meta->localRpcMeta().rpc_port); + close(meta->localRpcMeta().sockfd); } - context_ = new TcpContext(meta->localRpcMeta().rpc_port); + int tcp_port = meta->localRpcMeta().rpc_port + 1; + LOG(INFO) << "TcpTransport: listen on port " << tcp_port; + context_ = new TcpContext(tcp_port); running_ = true; thread_ = std::thread(&TcpTransport::worker, this); return 0; @@ -322,7 +324,8 @@ Status TcpTransport::submitTransfer( slice->target_id = request.target_id; slice->status = Slice::PENDING; task.slice_list.push_back(slice); - __sync_fetch_and_add(&task.slice_count, 1);; + __sync_fetch_and_add(&task.slice_count, 1); + ; startTransfer(slice); } @@ -345,7 +348,8 @@ Status TcpTransport::submitTransferTask( slice->target_id = request.target_id; slice->status = Slice::PENDING; task.slice_list.push_back(slice); - __sync_fetch_and_add(&task.slice_count, 1);; + __sync_fetch_and_add(&task.slice_count, 1); + ; startTransfer(slice); } return Status::OK(); @@ -379,7 +383,7 @@ void TcpTransport::startTransfer(Slice *slice) { } auto endpoint_iterator = resolver.resolve( boost::asio::ip::tcp::v4(), meta_entry.ip_or_host_name, - std::to_string(meta_entry.rpc_port)); + std::to_string(meta_entry.rpc_port + 1)); boost::asio::connect(socket, endpoint_iterator); auto session = std::make_shared(std::move(socket)); session->on_finalize_ = [slice](TransferStatusEnum status) { From 5011dd823507a51b2948c37723db551e38a7875e Mon Sep 17 00:00:00 2001 From: doujiang24 Date: Mon, 14 Apr 2025 14:03:19 +0800 Subject: [PATCH 2/4] compile ok Signed-off-by: doujiang24 --- mooncake-transfer-engine/include/common.h | 23 +++++++++++++------ .../include/transfer_metadata_plugin.h | 17 +++++--------- .../src/transfer_engine.cpp | 4 ---- .../src/transfer_metadata.cpp | 8 +++---- .../src/transfer_metadata_plugin.cpp | 19 +++++++-------- .../transport/tcp_transport/tcp_transport.cpp | 2 -- 6 files changed, 35 insertions(+), 38 deletions(-) diff --git a/mooncake-transfer-engine/include/common.h b/mooncake-transfer-engine/include/common.h index 30a4f9adc..ce1d2305d 100644 --- a/mooncake-transfer-engine/include/common.h +++ b/mooncake-transfer-engine/include/common.h @@ -44,6 +44,8 @@ const static int LOCAL_SEGMENT_ID = 0; enum class HandShakeRequestType { Connection = 0, Metadata = 1, + // placeholder for old protocol without RequestType + OldProtocol = 0xff, }; static inline int bindToSocket(int socket_id) { @@ -150,11 +152,15 @@ static inline int writeString(int fd, const HandShakeRequestType type, uint8_t byte = static_cast(type); LOG(INFO) << "writeString: type " << (int)byte << ", str(" << str.size() << "): " << str; - uint64_t length = str.size() + sizeof(byte); + uint64_t length = + str.size() + + (type == HandShakeRequestType::OldProtocol ? 0 : sizeof(byte)); if (writeFully(fd, &length, sizeof(length)) != (ssize_t)sizeof(length)) return ERR_SOCKET; - if (writeFully(fd, &byte, sizeof(byte)) != (ssize_t)sizeof(byte)) - return ERR_SOCKET; + if (type != HandShakeRequestType::OldProtocol) { + if (writeFully(fd, &byte, sizeof(byte)) != (ssize_t)sizeof(byte)) + return ERR_SOCKET; + } if (writeFully(fd, str.data(), length) != (ssize_t)length) return ERR_SOCKET; return 0; @@ -167,12 +173,12 @@ static inline std::pair readString(int fd) { uint64_t length = 0; ssize_t n = readFully(fd, &length, sizeof(length)); if (n != (ssize_t)sizeof(length)) { - LOG(ERROR) << "Failed to read length, got: " << n; + LOG(ERROR) << "readString: failed to read length, got: " << n; return {type, ""}; } if (length > kMaxLength) { - LOG(ERROR) << "Read too large length from socket: " << length; + LOG(ERROR) << "readString: too large length from socket: " << length; return {type, ""}; } @@ -180,14 +186,17 @@ static inline std::pair readString(int fd) { std::vector buffer(length); n = readFully(fd, buffer.data(), length); if (n != (ssize_t)length) { - LOG(ERROR) << "Failed to read string, got: " << n; + LOG(ERROR) << "readString: unexpected length, got: " << n + << ", expected: " << length; return {type, ""}; } if (buffer[0] <= static_cast(HandShakeRequestType::Metadata)) { type = static_cast(buffer[0]); - str.assign(buffer.data() + 1, length - 1); + str.assign(buffer.data() + sizeof(char), length - sizeof(char)); } else { + type = HandShakeRequestType::OldProtocol; + // Old protocol, no type str.assign(buffer.data(), length); } diff --git a/mooncake-transfer-engine/include/transfer_metadata_plugin.h b/mooncake-transfer-engine/include/transfer_metadata_plugin.h index 73d250fc1..337a4df02 100644 --- a/mooncake-transfer-engine/include/transfer_metadata_plugin.h +++ b/mooncake-transfer-engine/include/transfer_metadata_plugin.h @@ -37,14 +37,10 @@ struct HandShakePlugin { HandShakePlugin() {} virtual ~HandShakePlugin() {} - // When accept a new connection, this function will be called. - // The first param represents peer endpoint's attributes, while - // the second param represents local endpoint's attributes - using OnConnectionCallBack = - std::function; - - // When accept a new metadata request. - using OnMetadataCallBack = + // When accept a new connection/metadata request, this function will be + // called. The first param represents peer attributes, while the + // second param represents local attributes. + using OnReceiveCallBack = std::function; virtual int startDaemon(uint16_t listen_port, int sockfd) = 0; @@ -60,11 +56,10 @@ struct HandShakePlugin { Json::Value &peer_metadata) = 0; // Register callback function for receiving a new connection. - virtual void registerOnConnectionCallBack( - OnConnectionCallBack callback) = 0; + virtual void registerOnConnectionCallBack(OnReceiveCallBack callback) = 0; // Register callback function for receiving metadata exchange request. - virtual void registerOnMetadataCallBack(OnMetadataCallBack callback) = 0; + virtual void registerOnMetadataCallBack(OnReceiveCallBack callback) = 0; }; std::vector findLocalIpAddresses(); diff --git a/mooncake-transfer-engine/src/transfer_engine.cpp b/mooncake-transfer-engine/src/transfer_engine.cpp index 6751f1cb6..6e70932a5 100644 --- a/mooncake-transfer-engine/src/transfer_engine.cpp +++ b/mooncake-transfer-engine/src/transfer_engine.cpp @@ -125,10 +125,6 @@ Transport *TransferEngine::installTransport(const std::string &proto, int TransferEngine::uninstallTransport(const std::string &proto) { return 0; } -// port: env -// ip - -// ip:port Transport::SegmentHandle TransferEngine::openSegment( const std::string &segment_name) { if (segment_name.empty()) return ERR_INVALID_ARGUMENT; diff --git a/mooncake-transfer-engine/src/transfer_metadata.cpp b/mooncake-transfer-engine/src/transfer_metadata.cpp index e8c670be2..93ea65850 100644 --- a/mooncake-transfer-engine/src/transfer_metadata.cpp +++ b/mooncake-transfer-engine/src/transfer_metadata.cpp @@ -252,6 +252,7 @@ TransferMetadata::decodeSegmentDesc(Json::Value &segmentJSON, int TransferMetadata::receivePeerMetadata(const Json::Value &peer_json, Json::Value &local_json) { + // TODO: save to local cache // auto peer_desc = decodeSegmentDesc(peer_json, // peer_json["name"].asString()); auto local_desc = segment_id_to_desc_map_[LOCAL_SEGMENT_ID]; @@ -511,11 +512,8 @@ int TransferMetadata::sendHandshake(const std::string &peer_server_name, const HandShakeDesc &local_desc, HandShakeDesc &peer_desc) { RpcMetaDesc peer_location; - if (p2p_handshake_mode_) { - } else { - if (getRpcMetaEntry(peer_server_name, peer_location)) { - return ERR_METADATA; - } + if (getRpcMetaEntry(peer_server_name, peer_location)) { + return ERR_METADATA; } auto local = TransferHandshakeUtil::encode(local_desc); Json::Value peer; diff --git a/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp b/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp index cdf73d08c..250240533 100644 --- a/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp +++ b/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp @@ -492,11 +492,11 @@ struct SocketHandShakePlugin : public HandShakePlugin { } } - virtual void registerOnConnectionCallBack(OnConnectionCallBack callback) { + virtual void registerOnConnectionCallBack(OnReceiveCallBack callback) { on_connection_callback_ = callback; } - virtual void registerOnMetadataCallBack(OnMetadataCallBack callback) { + virtual void registerOnMetadataCallBack(OnReceiveCallBack callback) { on_metadata_callback_ = callback; } @@ -606,7 +606,9 @@ struct SocketHandShakePlugin : public HandShakePlugin { continue; } - if (type == HandShakeRequestType::Connection) { + // old protocol equals Connection type + if (type == HandShakeRequestType::Connection || + type == HandShakeRequestType::OldProtocol) { on_connection_callback_(peer, local); } else if (type == HandShakeRequestType::Metadata) { on_metadata_callback_(peer, local); @@ -619,7 +621,6 @@ struct SocketHandShakePlugin : public HandShakePlugin { int ret = writeString(conn_fd, type, Json::FastWriter{}.write(local)); - LOG(INFO) << "writeString return: " << ret; if (ret) { LOG(ERROR) << "SocketHandShakePlugin: failed to send " "message: " @@ -779,7 +780,7 @@ struct SocketHandShakePlugin : public HandShakePlugin { int ret = 0; for (rp = result; rp; rp = rp->ai_next) { - ret = doSendMetadta(rp, local_metadata, peer_metadata); + ret = doSendMetadata(rp, local_metadata, peer_metadata); if (ret == 0) { freeaddrinfo(result); return 0; @@ -793,8 +794,8 @@ struct SocketHandShakePlugin : public HandShakePlugin { return ret; } - int doSendMetadta(struct addrinfo *addr, const Json::Value &local_metadata, - Json::Value &peer_metadata) { + int doSendMetadata(struct addrinfo *addr, const Json::Value &local_metadata, + Json::Value &peer_metadata) { int conn_fd = -1; int ret = doConnect(addr, conn_fd); if (ret) { @@ -839,8 +840,8 @@ struct SocketHandShakePlugin : public HandShakePlugin { std::thread listener_; int listen_fd_; - OnConnectionCallBack on_connection_callback_; - OnMetadataCallBack on_metadata_callback_; + OnReceiveCallBack on_connection_callback_; + OnReceiveCallBack on_metadata_callback_; }; std::shared_ptr HandShakePlugin::Create( diff --git a/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp b/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp index 9d003a31f..3d943bf66 100644 --- a/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp +++ b/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp @@ -325,7 +325,6 @@ Status TcpTransport::submitTransfer( slice->status = Slice::PENDING; task.slice_list.push_back(slice); __sync_fetch_and_add(&task.slice_count, 1); - ; startTransfer(slice); } @@ -349,7 +348,6 @@ Status TcpTransport::submitTransferTask( slice->status = Slice::PENDING; task.slice_list.push_back(slice); __sync_fetch_and_add(&task.slice_count, 1); - ; startTransfer(slice); } return Status::OK(); From c542b1865f4e70b2e500905044ce5bd9631ce458 Mon Sep 17 00:00:00 2001 From: doujiang24 Date: Tue, 15 Apr 2025 09:34:55 +0800 Subject: [PATCH 3/4] update binding. Signed-off-by: doujiang24 --- .../transfer_engine/transfer_engine_py.cpp | 34 +++++++++++-------- mooncake-integration/vllm/vllm_adaptor.cpp | 14 ++------ .../src/transfer_engine.cpp | 24 ++++++++----- 3 files changed, 38 insertions(+), 34 deletions(-) diff --git a/mooncake-integration/transfer_engine/transfer_engine_py.cpp b/mooncake-integration/transfer_engine/transfer_engine_py.cpp index 631499759..ed5b1def0 100644 --- a/mooncake-integration/transfer_engine/transfer_engine_py.cpp +++ b/mooncake-integration/transfer_engine/transfer_engine_py.cpp @@ -61,6 +61,9 @@ std::pair parseConnectionString( if (pos != std::string::npos) { proto = conn_string.substr(0, pos); domain = conn_string.substr(pos + 3); + } else if (conn_string == P2PHANDSHAKE) { + proto = ""; + domain = P2PHANDSHAKE; } else { domain = conn_string; } @@ -70,6 +73,18 @@ std::pair parseConnectionString( return result; } +std::string buildConnString(const std::string &metadata_type, + const std::string &metadata_server) { + if (metadata_server == P2PHANDSHAKE) { + return P2PHANDSHAKE; + } + + std::string conn_string = metadata_server; + if (conn_string.find("://") == std::string::npos) + conn_string = metadata_type + "://" + metadata_server; + return conn_string; +} + int TransferEnginePy::initialize(const char *local_hostname, const char *metadata_server, const char *protocol, @@ -84,24 +99,13 @@ int TransferEnginePy::initializeExt(const char *local_hostname, const char *protocol, const char *device_name, const char *metadata_type) { - std::string conn_string = metadata_server; - if (conn_string.find("://") == std::string::npos) - conn_string = - std::string(metadata_type) + "://" + std::string(metadata_server); + std::string conn_string = buildConnString(metadata_type, metadata_server); // TODO: remove `false` in the feature, it's for keep same API in SGLang. engine_ = std::make_unique(false); - if (getenv("MC_LEGACY_RPC_PORT_BINDING")) { - auto hostname_port = parseHostNameWithPort(local_hostname); - int ret = - engine_->init(conn_string, local_hostname, - hostname_port.first.c_str(), hostname_port.second); - if (ret) return -1; - } else { - // the last two params are unused - int ret = engine_->init(conn_string, local_hostname, "", 0); - if (ret) return -1; - } + // the last two params are unused + int ret = engine_->init(conn_string, local_hostname, "", 0); + if (ret) return -1; xport_ = nullptr; if (strcmp(protocol, "rdma") == 0) { diff --git a/mooncake-integration/vllm/vllm_adaptor.cpp b/mooncake-integration/vllm/vllm_adaptor.cpp index 181d53187..7ee9a53be 100644 --- a/mooncake-integration/vllm/vllm_adaptor.cpp +++ b/mooncake-integration/vllm/vllm_adaptor.cpp @@ -86,17 +86,9 @@ int VLLMAdaptor::initializeExt(const char *local_hostname, // TODO: remove `false` in the feature, it's for keep same API in vllm. engine_ = std::make_unique(false); - if (getenv("MC_LEGACY_RPC_PORT_BINDING")) { - auto hostname_port = parseHostNameWithPort(local_hostname); - int ret = - engine_->init(conn_string, local_hostname, - hostname_port.first.c_str(), hostname_port.second); - if (ret) return -1; - } else { - // the last two params are unused - int ret = engine_->init(conn_string, local_hostname, "", 0); - if (ret) return -1; - } + // the last two params are unused + int ret = engine_->init(conn_string, local_hostname, "", 0); + if (ret) return -1; xport_ = nullptr; if (strcmp(protocol, "rdma") == 0) { diff --git a/mooncake-transfer-engine/src/transfer_engine.cpp b/mooncake-transfer-engine/src/transfer_engine.cpp index 6e70932a5..0a4039ffb 100644 --- a/mooncake-transfer-engine/src/transfer_engine.cpp +++ b/mooncake-transfer-engine/src/transfer_engine.cpp @@ -23,15 +23,12 @@ int TransferEngine::init(const std::string &metadata_conn_string, const std::string &local_server_name, const std::string &ip_or_host_name, uint64_t rpc_port) { - local_server_name_ = local_server_name; - metadata_ = std::make_shared(metadata_conn_string); - multi_transports_ = - std::make_shared(metadata_, local_server_name_); - TransferMetadata::RpcMetaDesc desc; - if (getenv("MC_LEGACY_RPC_PORT_BINDING")) { - desc.ip_or_host_name = ip_or_host_name; - desc.rpc_port = rpc_port; + if (getenv("MC_LEGACY_RPC_PORT_BINDING") || + metadata_conn_string == P2PHANDSHAKE) { + auto [host_name, port] = parseHostNameWithPort(local_server_name); + desc.ip_or_host_name = host_name; + desc.rpc_port = port; desc.sockfd = -1; } else { (void)(ip_or_host_name); @@ -62,6 +59,17 @@ int TransferEngine::init(const std::string &metadata_conn_string, << " and port " << desc.rpc_port << " for serving local TCP service"; + if (metadata_conn_string == P2PHANDSHAKE) { + local_server_name_ = + desc.ip_or_host_name + ":" + std::to_string(desc.rpc_port); + } else { + local_server_name_ = local_server_name; + } + + metadata_ = std::make_shared(metadata_conn_string); + multi_transports_ = + std::make_shared(metadata_, local_server_name_); + int ret = metadata_->addRpcMetaEntry(local_server_name_, desc); if (ret) return ret; From 956d92b9306a6d98d4df80eae66a5b6e379d633c Mon Sep 17 00:00:00 2001 From: doujiang24 Date: Tue, 15 Apr 2025 13:43:02 +0800 Subject: [PATCH 4/4] add the get_rpc_port api. Signed-off-by: doujiang24 --- .../transfer_engine/transfer_engine_py.cpp | 3 +++ .../transfer_engine/transfer_engine_py.h | 14 ++++++----- .../include/transfer_engine.h | 4 +++- .../src/transfer_engine.cpp | 24 +++++++++++++------ .../transport/tcp_transport/tcp_transport.cpp | 3 --- 5 files changed, 31 insertions(+), 17 deletions(-) diff --git a/mooncake-integration/transfer_engine/transfer_engine_py.cpp b/mooncake-integration/transfer_engine/transfer_engine_py.cpp index ed5b1def0..503e09c54 100644 --- a/mooncake-integration/transfer_engine/transfer_engine_py.cpp +++ b/mooncake-integration/transfer_engine/transfer_engine_py.cpp @@ -131,6 +131,8 @@ int TransferEnginePy::initializeExt(const char *local_hostname, return 0; } +int TransferEnginePy::getRpcPort() { return engine_->getRpcPort(); } + char *TransferEnginePy::allocateRawBuffer(size_t capacity) { auto buffer = malloc(capacity); if (!buffer) return nullptr; @@ -359,6 +361,7 @@ PYBIND11_MODULE(engine, m) { .def(py::init<>()) .def("initialize", &TransferEnginePy::initialize) .def("initialize_ext", &TransferEnginePy::initializeExt) + .def("get_rpc_port", &TransferEnginePy::getRpcPort) .def("allocate_managed_buffer", &TransferEnginePy::allocateManagedBuffer) .def("free_managed_buffer", &TransferEnginePy::freeManagedBuffer) diff --git a/mooncake-integration/transfer_engine/transfer_engine_py.h b/mooncake-integration/transfer_engine/transfer_engine_py.h index 9312d15ed..26003762d 100644 --- a/mooncake-integration/transfer_engine/transfer_engine_py.h +++ b/mooncake-integration/transfer_engine/transfer_engine_py.h @@ -41,10 +41,8 @@ const static size_t kSlabSizeKB[] = { class TransferEnginePy { public: - enum class TransferOpcode { - READ = 0, - WRITE = 1 - }; + enum class TransferOpcode { READ = 0, WRITE = 1 }; + public: TransferEnginePy(); @@ -57,6 +55,8 @@ class TransferEnginePy { const char *protocol, const char *device_name, const char *metadata_type); + int getRpcPort(); + uintptr_t allocateManagedBuffer(size_t length); int freeManagedBuffer(uintptr_t user_tensor, size_t length); @@ -68,10 +68,11 @@ class TransferEnginePy { uintptr_t peer_buffer_address, size_t length); int transferSync(const char *target_hostname, uintptr_t buffer, - uintptr_t peer_buffer_address, size_t length, TransferOpcode opcode); + uintptr_t peer_buffer_address, size_t length, + TransferOpcode opcode); uintptr_t getFirstBufferAddress(const std::string &segment_name); - + int writeBytesToBuffer(uintptr_t dest_address, char *src_ptr, size_t length) { memcpy((void *)dest_address, (void *)src_ptr, length); @@ -90,6 +91,7 @@ class TransferEnginePy { // must be called before TransferEnginePy::~TransferEnginePy() int unregisterMemory(uintptr_t buffer_addr); + private: char *allocateRawBuffer(size_t capacity); diff --git a/mooncake-transfer-engine/include/transfer_engine.h b/mooncake-transfer-engine/include/transfer_engine.h index e669c635b..5cebc3256 100644 --- a/mooncake-transfer-engine/include/transfer_engine.h +++ b/mooncake-transfer-engine/include/transfer_engine.h @@ -54,7 +54,7 @@ class TransferEngine { int init(const std::string &metadata_conn_string, const std::string &local_server_name, - const std::string &ip_or_host_name = "", + const std::string &ip_or_host_name = "", uint64_t rpc_port = 12345); int freeEngine(); @@ -64,6 +64,8 @@ class TransferEngine { int uninstallTransport(const std::string &proto); + int getRpcPort(); + SegmentHandle openSegment(const std::string &segment_name); int closeSegment(SegmentHandle handle); diff --git a/mooncake-transfer-engine/src/transfer_engine.cpp b/mooncake-transfer-engine/src/transfer_engine.cpp index 0a4039ffb..f605d9d3f 100644 --- a/mooncake-transfer-engine/src/transfer_engine.cpp +++ b/mooncake-transfer-engine/src/transfer_engine.cpp @@ -23,6 +23,7 @@ int TransferEngine::init(const std::string &metadata_conn_string, const std::string &local_server_name, const std::string &ip_or_host_name, uint64_t rpc_port) { + local_server_name_ = local_server_name; TransferMetadata::RpcMetaDesc desc; if (getenv("MC_LEGACY_RPC_PORT_BINDING") || metadata_conn_string == P2PHANDSHAKE) { @@ -30,6 +31,20 @@ int TransferEngine::init(const std::string &metadata_conn_string, desc.ip_or_host_name = host_name; desc.rpc_port = port; desc.sockfd = -1; + + if (metadata_conn_string == P2PHANDSHAKE) { + // use random port when no port is specified + if (port == getDefaultHandshakePort()) { + desc.rpc_port = findAvailableTcpPort(desc.sockfd); + if (desc.rpc_port == 0) { + LOG(ERROR) + << "not valid port for serving local TCP service"; + return -1; + } + } + local_server_name_ = + desc.ip_or_host_name + ":" + std::to_string(desc.rpc_port); + } } else { (void)(ip_or_host_name); auto *ip_address = getenv("MC_TCP_BIND_ADDRESS"); @@ -59,13 +74,6 @@ int TransferEngine::init(const std::string &metadata_conn_string, << " and port " << desc.rpc_port << " for serving local TCP service"; - if (metadata_conn_string == P2PHANDSHAKE) { - local_server_name_ = - desc.ip_or_host_name + ":" + std::to_string(desc.rpc_port); - } else { - local_server_name_ = local_server_name; - } - metadata_ = std::make_shared(metadata_conn_string); multi_transports_ = std::make_shared(metadata_, local_server_name_); @@ -133,6 +141,8 @@ Transport *TransferEngine::installTransport(const std::string &proto, int TransferEngine::uninstallTransport(const std::string &proto) { return 0; } +int TransferEngine::getRpcPort() { return metadata_->localRpcMeta().rpc_port; } + Transport::SegmentHandle TransferEngine::openSegment( const std::string &segment_name) { if (segment_name.empty()) return ERR_INVALID_ARGUMENT; diff --git a/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp b/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp index 3d943bf66..8ac8b7283 100644 --- a/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp +++ b/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp @@ -220,9 +220,6 @@ int TcpTransport::install(std::string &local_server_name, return -1; } - if (meta->localRpcMeta().sockfd) { - close(meta->localRpcMeta().sockfd); - } int tcp_port = meta->localRpcMeta().rpc_port + 1; LOG(INFO) << "TcpTransport: listen on port " << tcp_port; context_ = new TcpContext(tcp_port);