Skip to content

[TransferEngine] feature: support p2phandshake without metadata server #240

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 45 additions & 29 deletions mooncake-integration/transfer_engine/transfer_engine_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ std::pair<std::string, std::string> parseConnectionString(
if (pos != std::string::npos) {
proto = conn_string.substr(0, pos);
domain = conn_string.substr(pos + 3);
} else if (conn_string == P2PHANDSHAKE) {
proto = "";
domain = P2PHANDSHAKE;
} else {
domain = conn_string;
}
Expand All @@ -70,43 +73,47 @@ std::pair<std::string, std::string> parseConnectionString(
return result;
}

// Builds the metadata connection string handed to the transfer engine.
//
// - If `metadata_server` is the P2PHANDSHAKE sentinel, the sentinel itself is
//   returned unchanged.
// - If `metadata_server` already contains "://", it is returned as-is.
// - Otherwise the result is "<metadata_type>://<metadata_server>".
std::string buildConnString(const std::string &metadata_type,
                            const std::string &metadata_server) {
    if (metadata_server == P2PHANDSHAKE) return P2PHANDSHAKE;

    const bool has_scheme = metadata_server.find("://") != std::string::npos;
    if (has_scheme) return metadata_server;

    return metadata_type + "://" + metadata_server;
}

int TransferEnginePy::initialize(const char *local_hostname,
const char *metadata_server, const char *protocol,
const char *device_name) {
const char *metadata_server,
const char *protocol,
const char *device_name) {
auto conn_string = parseConnectionString(metadata_server);
return initializeExt(local_hostname, conn_string.second.c_str(), protocol,
device_name, conn_string.first.c_str());
}

int TransferEnginePy::initializeExt(const char *local_hostname,
const char *metadata_server,
const char *protocol, const char *device_name,
const char *metadata_type) {
std::string conn_string = metadata_server;
if (conn_string.find("://") == std::string::npos)
conn_string =
std::string(metadata_type) + "://" + std::string(metadata_server);
const char *metadata_server,
const char *protocol,
const char *device_name,
const char *metadata_type) {
std::string conn_string = buildConnString(metadata_type, metadata_server);

// TODO: remove `false` in the future; it is kept only to preserve the same API in SGLang.
engine_ = std::make_unique<TransferEngine>(false);
if (getenv("MC_LEGACY_RPC_PORT_BINDING")) {
auto hostname_port = parseHostNameWithPort(local_hostname);
int ret =
engine_->init(conn_string, local_hostname,
hostname_port.first.c_str(), hostname_port.second);
if (ret) return -1;
} else {
// the last two params are unused
int ret = engine_->init(conn_string, local_hostname, "", 0);
if (ret) return -1;
}
// the last two params are unused
int ret = engine_->init(conn_string, local_hostname, "", 0);
if (ret) return -1;

xport_ = nullptr;
if (strcmp(protocol, "rdma") == 0) {
auto device_names = formatDeviceNames(device_name);
std::string nic_priority_matrix =
"{\"cpu:0\": [[" + device_names + "], []],"
"\"cuda:0\": [[" + device_names + "], []]}";
std::string nic_priority_matrix = "{\"cpu:0\": [[" + device_names +
"], []],"
"\"cuda:0\": [[" +
device_names + "], []]}";
void **args = (void **)malloc(2 * sizeof(void *));
args[0] = (void *)nic_priority_matrix.c_str();
args[1] = nullptr;
Expand All @@ -124,6 +131,8 @@ int TransferEnginePy::initializeExt(const char *local_hostname,
return 0;
}

// Returns the RPC port reported by the underlying TransferEngine.
// `engine_` is created in initializeExt(), so this is only meaningful after
// the engine has been initialized.
int TransferEnginePy::getRpcPort() {
    const int rpc_port = engine_->getRpcPort();
    return rpc_port;
}

char *TransferEnginePy::allocateRawBuffer(size_t capacity) {
auto buffer = malloc(capacity);
if (!buffer) return nullptr;
Expand Down Expand Up @@ -193,8 +202,10 @@ int TransferEnginePy::freeManagedBuffer(uintptr_t buffer_addr, size_t length) {
return 0;
}

int TransferEnginePy::transferSyncWrite(const char *target_hostname, uintptr_t buffer,
uintptr_t peer_buffer_address, size_t length) {
int TransferEnginePy::transferSyncWrite(const char *target_hostname,
uintptr_t buffer,
uintptr_t peer_buffer_address,
size_t length) {
Transport::SegmentHandle handle;
if (handle_map_.count(target_hostname)) {
handle = handle_map_[target_hostname];
Expand Down Expand Up @@ -229,8 +240,10 @@ int TransferEnginePy::transferSyncWrite(const char *target_hostname, uintptr_t b
}
}

int TransferEnginePy::transferSyncRead(const char *target_hostname, uintptr_t buffer,
uintptr_t peer_buffer_address, size_t length) {
int TransferEnginePy::transferSyncRead(const char *target_hostname,
uintptr_t buffer,
uintptr_t peer_buffer_address,
size_t length) {
Transport::SegmentHandle handle;
if (handle_map_.count(target_hostname)) {
handle = handle_map_[target_hostname];
Expand Down Expand Up @@ -348,13 +361,16 @@ PYBIND11_MODULE(engine, m) {
.def(py::init<>())
.def("initialize", &TransferEnginePy::initialize)
.def("initialize_ext", &TransferEnginePy::initializeExt)
.def("allocate_managed_buffer", &TransferEnginePy::allocateManagedBuffer)
.def("get_rpc_port", &TransferEnginePy::getRpcPort)
.def("allocate_managed_buffer",
&TransferEnginePy::allocateManagedBuffer)
.def("free_managed_buffer", &TransferEnginePy::freeManagedBuffer)
.def("transfer_sync_write", &TransferEnginePy::transferSyncWrite)
.def("transfer_sync_read", &TransferEnginePy::transferSyncRead)
.def("transfer_sync", &TransferEnginePy::transferSync)
.def("write_bytes_to_buffer", &TransferEnginePy::writeBytesToBuffer)
.def("read_bytes_from_buffer", &TransferEnginePy::readBytesFromBuffer)
.def("read_bytes_from_buffer",
&TransferEnginePy::readBytesFromBuffer)
.def("register_memory", &TransferEnginePy::registerMemory)
.def("unregister_memory", &TransferEnginePy::unregisterMemory)
.def("get_first_buffer_address",
Expand Down
14 changes: 8 additions & 6 deletions mooncake-integration/transfer_engine/transfer_engine_py.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,8 @@ const static size_t kSlabSizeKB[] = {

class TransferEnginePy {
public:
enum class TransferOpcode {
READ = 0,
WRITE = 1
};
enum class TransferOpcode { READ = 0, WRITE = 1 };

public:
TransferEnginePy();

Expand All @@ -57,6 +55,8 @@ class TransferEnginePy {
const char *protocol, const char *device_name,
const char *metadata_type);

int getRpcPort();

uintptr_t allocateManagedBuffer(size_t length);

int freeManagedBuffer(uintptr_t user_tensor, size_t length);
Expand All @@ -68,10 +68,11 @@ class TransferEnginePy {
uintptr_t peer_buffer_address, size_t length);

int transferSync(const char *target_hostname, uintptr_t buffer,
uintptr_t peer_buffer_address, size_t length, TransferOpcode opcode);
uintptr_t peer_buffer_address, size_t length,
TransferOpcode opcode);

uintptr_t getFirstBufferAddress(const std::string &segment_name);

int writeBytesToBuffer(uintptr_t dest_address, char *src_ptr,
size_t length) {
memcpy((void *)dest_address, (void *)src_ptr, length);
Expand All @@ -90,6 +91,7 @@ class TransferEnginePy {

// must be called before TransferEnginePy::~TransferEnginePy()
int unregisterMemory(uintptr_t buffer_addr);

private:
char *allocateRawBuffer(size_t capacity);

Expand Down
14 changes: 3 additions & 11 deletions mooncake-integration/vllm/vllm_adaptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,9 @@ int VLLMAdaptor::initializeExt(const char *local_hostname,

// TODO: remove `false` in the future; it is kept only to preserve the same API in vllm.
engine_ = std::make_unique<TransferEngine>(false);
if (getenv("MC_LEGACY_RPC_PORT_BINDING")) {
auto hostname_port = parseHostNameWithPort(local_hostname);
int ret =
engine_->init(conn_string, local_hostname,
hostname_port.first.c_str(), hostname_port.second);
if (ret) return -1;
} else {
// the last two params are unused
int ret = engine_->init(conn_string, local_hostname, "", 0);
if (ret) return -1;
}
// the last two params are unused
int ret = engine_->init(conn_string, local_hostname, "", 0);
if (ret) return -1;

xport_ = nullptr;
if (strcmp(protocol, "rdma") == 0) {
Expand Down
6 changes: 4 additions & 2 deletions mooncake-transfer-engine/example/http-metadata-server/go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module github.com/kvcache-ai/Mooncake/mooncake-transfer-engine/example/http-metadata-server

go 1.22.9
go 1.23.0

toolchain go1.23.8

require github.com/gin-gonic/gin v1.10.0

Expand All @@ -25,7 +27,7 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/crypto v0.35.0 // indirect
golang.org/x/net v0.36.0 // indirect
golang.org/x/sys v0.30.0 // indirect
golang.org/x/text v0.22.0 // indirect
Expand Down
58 changes: 49 additions & 9 deletions mooncake-transfer-engine/include/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@
namespace mooncake {
const static int LOCAL_SEGMENT_ID = 0;

// Tag describing the kind of a handshake message. writeString() emits it as a
// single byte ahead of the payload, and readString() recovers it from the
// first payload byte.
enum class HandShakeRequestType {
    Connection = 0x00,
    Metadata = 0x01,
    // Placeholder for messages of the old protocol, which carry no
    // request-type byte at all.
    OldProtocol = 0xff,
};

static inline int bindToSocket(int socket_id) {
if (unlikely(numa_available() < 0)) {
LOG(ERROR) << "The platform does not support NUMA";
Expand Down Expand Up @@ -140,27 +147,60 @@ static inline ssize_t readFully(int fd, void *buf, size_t len) {
return len;
}

static inline int writeString(int fd, const std::string &str) {
uint64_t length = str.size();
// Sends one length-prefixed handshake message on `fd`.
//
// Wire layout: a uint64_t length, then (unless `type` is OldProtocol) one
// request-type byte, then the raw bytes of `str`. The length field counts
// everything that follows it, i.e. the optional type byte plus the string.
//
// Returns 0 on success, or ERR_SOCKET if any of the writes comes up short.
static inline int writeString(int fd, const HandShakeRequestType type,
                              const std::string &str) {
    const uint8_t byte = static_cast<uint8_t>(type);
    const bool has_type_byte = type != HandShakeRequestType::OldProtocol;
    LOG(INFO) << "writeString: type " << static_cast<int>(byte) << ", str("
              << str.size() << "): " << str;

    const uint64_t length = str.size() + (has_type_byte ? sizeof(byte) : 0);
    if (writeFully(fd, &length, sizeof(length)) != (ssize_t)sizeof(length))
        return ERR_SOCKET;

    if (has_type_byte) {
        if (writeFully(fd, &byte, sizeof(byte)) != (ssize_t)sizeof(byte))
            return ERR_SOCKET;
    }

    // Write exactly the string bytes. `length` also counts the type byte, so
    // using it here would push one extra byte (the string's terminating NUL)
    // onto the socket and leave the stream out of sync with the announced
    // length.
    const size_t payload_size = str.size();
    if (writeFully(fd, str.data(), payload_size) != (ssize_t)payload_size)
        return ERR_SOCKET;
    return 0;
}

static inline std::string readString(int fd) {
// Reads one length-prefixed handshake message from `fd`.
//
// The message starts with a uint64_t length followed by that many payload
// bytes. If the first payload byte is a known request type (<= Metadata), it
// is stripped and returned as the type; otherwise the whole payload is
// treated as an old-protocol string without a type byte.
//
// Returns {type, payload}. On any failure (short read, length above 1 MiB)
// and for a zero-length message, the payload is empty and the type is left at
// its default, HandShakeRequestType::Connection.
static inline std::pair<HandShakeRequestType, std::string> readString(int fd) {
    HandShakeRequestType type = HandShakeRequestType::Connection;

    const static size_t kMaxLength = 1ull << 20;
    uint64_t length = 0;
    ssize_t n = readFully(fd, &length, sizeof(length));
    if (n != (ssize_t)sizeof(length)) {
        LOG(ERROR) << "readString: failed to read length, got: " << n;
        return {type, ""};
    }

    if (length > kMaxLength) {
        LOG(ERROR) << "readString: too large length from socket: " << length;
        return {type, ""};
    }

    // A zero-length message has neither a type byte nor a payload. Return
    // early so the first-byte inspection below never touches an empty buffer.
    if (length == 0) {
        return {type, ""};
    }

    std::vector<char> buffer(length);
    n = readFully(fd, buffer.data(), length);
    if (n != (ssize_t)length) {
        LOG(ERROR) << "readString: unexpected length, got: " << n
                   << ", expected: " << length;
        return {type, ""};
    }

    // Compare as an unsigned byte: with a signed `char`, bytes >= 0x80 are
    // negative and would wrongly pass the `<= Metadata` test.
    const uint8_t tag = static_cast<uint8_t>(buffer[0]);
    std::string str;
    if (tag <= static_cast<uint8_t>(HandShakeRequestType::Metadata)) {
        type = static_cast<HandShakeRequestType>(tag);
        str.assign(buffer.data() + sizeof(char), length - sizeof(char));
    } else {
        // Old protocol, no type byte: the whole payload is the string.
        type = HandShakeRequestType::OldProtocol;
        str.assign(buffer.data(), length);
    }

    return {type, str};
}

const static std::string NIC_PATH_DELIM = "@";
Expand Down
4 changes: 3 additions & 1 deletion mooncake-transfer-engine/include/transfer_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class TransferEngine {

int init(const std::string &metadata_conn_string,
const std::string &local_server_name,
const std::string &ip_or_host_name = "",
const std::string &ip_or_host_name = "",
uint64_t rpc_port = 12345);

int freeEngine();
Expand All @@ -64,6 +64,8 @@ class TransferEngine {

int uninstallTransport(const std::string &proto);

int getRpcPort();

SegmentHandle openSegment(const std::string &segment_name);

int closeSegment(SegmentHandle handle);
Expand Down
11 changes: 10 additions & 1 deletion mooncake-transfer-engine/include/transfer_metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ namespace mooncake {
struct MetadataStoragePlugin;
struct HandShakePlugin;

#define P2PHANDSHAKE "P2PHANDSHAKE"

class TransferMetadata {
public:
struct DeviceDesc {
Expand Down Expand Up @@ -73,7 +75,7 @@ class TransferMetadata {
struct RpcMetaDesc {
std::string ip_or_host_name;
uint16_t rpc_port;
int sockfd; // local cache
int sockfd; // local cache
};

struct HandShakeDesc {
Expand Down Expand Up @@ -134,6 +136,13 @@ class TransferMetadata {
HandShakeDesc &peer_desc);

private:
int encodeSegmentDesc(const SegmentDesc &desc, Json::Value &segmentJSON);
std::shared_ptr<TransferMetadata::SegmentDesc> decodeSegmentDesc(
Json::Value &segmentJSON, const std::string &segment_name);
int receivePeerMetadata(const Json::Value &peer_json,
Json::Value &local_json);

bool p2p_handshake_mode_{false};
// local cache
RWSpinlock segment_lock_;
std::unordered_map<uint64_t, std::shared_ptr<SegmentDesc>>
Expand Down
20 changes: 15 additions & 5 deletions mooncake-transfer-engine/include/transfer_metadata_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,29 @@ struct HandShakePlugin {
HandShakePlugin() {}
virtual ~HandShakePlugin() {}

// When accept a new connection, this function will be called.
// The first param represents peer endpoint's attributes, while
// the second param represents local endpoint's attributes
// Called when a new connection or metadata request is accepted.
// The first param represents the peer's attributes, while the
// second param represents the local attributes.
using OnReceiveCallBack =
std::function<int(const Json::Value &, Json::Value &)>;

virtual int startDaemon(OnReceiveCallBack on_recv_callback,
uint16_t listen_port, int sockfd) = 0;
virtual int startDaemon(uint16_t listen_port, int sockfd) = 0;

// Connect to peer endpoint, and wait for receiving
// peer endpoint's attributes
virtual int send(std::string ip_or_host_name, uint16_t rpc_port,
const Json::Value &local, Json::Value &peer) = 0;

// Exchange metadata with remote peer.
virtual int exchangeMetadata(std::string ip_or_host_name, uint16_t rpc_port,
const Json::Value &local_metadata,
Json::Value &peer_metadata) = 0;

// Register callback function for receiving a new connection.
virtual void registerOnConnectionCallBack(OnReceiveCallBack callback) = 0;

// Register callback function for receiving metadata exchange request.
virtual void registerOnMetadataCallBack(OnReceiveCallBack callback) = 0;
};

std::vector<std::string> findLocalIpAddresses();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class RdmaTransport : public Transport {
// TRANSFER

Status submitTransfer(BatchID batch_id,
const std::vector<TransferRequest> &entries) override;
const std::vector<TransferRequest> &entries) override;

Status submitTransferTask(
const std::vector<TransferRequest *> &request_list,
Expand Down
Loading
Loading