Skip to content

Commit dca90f1

Browse files
authored
[PD] Remove the requirement of config file for mooncake backend (#5460)
1 parent 0961fee commit dca90f1

File tree

5 files changed

+44
-108
lines changed

5 files changed

+44
-108
lines changed

python/sglang/srt/disaggregation/decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def _init_kv_manager(self) -> BaseKVManager:
121121
kv_args.aux_item_lens = [
122122
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
123123
]
124-
kv_args.ib_device = "mock-ib-device"
124+
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
125125
kv_args.gpu_id = self.scheduler.gpu_id
126126
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
127127
kv_manager = kv_manager_class(

python/sglang/srt/disaggregation/mooncake/conn.py

Lines changed: 5 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,12 @@ def __init__(
9999
disaggregation_mode: DisaggregationMode,
100100
server_args: ServerArgs,
101101
):
102-
self.engine = MooncakeTransferEngine()
103102
self.kv_args = args
103+
self.engine = MooncakeTransferEngine(
104+
hostname=get_local_ip_by_remote(),
105+
gpu_id=self.kv_args.gpu_id,
106+
ib_device=self.kv_args.ib_device,
107+
)
104108
self.disaggregation_mode = disaggregation_mode
105109
# for p/d multi node infer
106110
self.bootstrap_port = server_args.disaggregation_bootstrap_port
@@ -503,52 +507,8 @@ def run(self):
503507
self.thread.start()
504508

505509
def _setup_routes(self):
506-
self.app.router.add_route("*", "/metadata", self._handle_metadata)
507510
self.app.router.add_route("*", "/route", self._handle_route)
508511

509-
async def _handle_metadata(self, request: web.Request):
510-
key = request.query.get("key", "")
511-
512-
if request.method == "GET":
513-
return await self._handle_metadata_get(key)
514-
elif request.method == "PUT":
515-
return await self._handle_metadata_put(key, request)
516-
elif request.method == "DELETE":
517-
return await self._handle_metadata_delete(key)
518-
return web.Response(
519-
text="Method not allowed", status=405, content_type="application/json"
520-
)
521-
522-
async def _handle_metadata_get(self, key):
523-
async with self.lock:
524-
value = self.store.get(key)
525-
if value is None:
526-
return web.Response(
527-
text="metadata not found", status=404, content_type="application/json"
528-
)
529-
return web.Response(body=value, status=200, content_type="application/json")
530-
531-
async def _handle_metadata_put(self, key, request):
532-
data = await request.read()
533-
async with self.lock:
534-
self.store[key] = data
535-
return web.Response(
536-
text="metadata updated", status=200, content_type="application/json"
537-
)
538-
539-
async def _handle_metadata_delete(self, key):
540-
async with self.lock:
541-
if key not in self.store:
542-
return web.Response(
543-
text="metadata not found",
544-
status=404,
545-
content_type="application/json",
546-
)
547-
del self.store[key]
548-
return web.Response(
549-
text="metadata deleted", status=200, content_type="application/json"
550-
)
551-
552512
async def _handle_route(self, request: web.Request):
553513
method = request.method
554514
if method == "PUT":
Lines changed: 30 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,14 @@
11
import json
22
import logging
3-
import os
4-
import uuid
53
from dataclasses import dataclass
4+
from typing import Optional
65

76
logger = logging.getLogger(__name__)
87

98

10-
@dataclass
11-
class MooncakeTransferEngineConfig:
12-
local_hostname: str
13-
metadata_server: str
14-
protocol: str
15-
device_name: str
16-
17-
@staticmethod
18-
def from_file(file_path: str) -> "MooncakeTransferEngineConfig":
19-
"""Load the config from a JSON file."""
20-
with open(file_path) as fin:
21-
config = json.load(fin)
22-
return MooncakeTransferEngineConfig(
23-
local_hostname=config.get("local_hostname", None),
24-
metadata_server=config.get("metadata_server"),
25-
protocol=config.get("protocol", "rdma"),
26-
device_name=config.get("device_name", ""),
27-
)
28-
29-
@staticmethod
30-
def load_from_env() -> "MooncakeTransferEngineConfig":
31-
"""Load config from a file specified in the environment variable."""
32-
config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
33-
if config_file_path is None:
34-
raise ValueError(
35-
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
36-
)
37-
return MooncakeTransferEngineConfig.from_file(config_file_path)
38-
39-
409
class MooncakeTransferEngine:
4110

42-
def __init__(self):
11+
def __init__(self, hostname: str, gpu_id: int, ib_device: Optional[str] = None):
4312
try:
4413
from mooncake.engine import TransferEngine
4514
except ImportError as e:
@@ -50,43 +19,43 @@ def __init__(self):
5019
) from e
5120

5221
self.engine = TransferEngine()
22+
self.hostname = hostname
23+
self.gpu_id = gpu_id
24+
self.ib_device = ib_device
5325

54-
try:
55-
self.config = MooncakeTransferEngineConfig.load_from_env()
56-
logger.info("Mooncake Configuration loaded successfully.")
57-
except ValueError as e:
58-
logger.error(e)
59-
raise
60-
except Exception as exc:
61-
logger.error("An error occurred while loading the configuration: %s", exc)
62-
raise
63-
64-
self.config = MooncakeTransferEngineConfig.load_from_env()
65-
66-
session_suffix = "_" + str(uuid.uuid4())
67-
self.session_id = self.config.local_hostname + session_suffix
6826
self.initialize(
69-
self.session_id,
70-
self.config.metadata_server,
71-
self.config.protocol,
72-
self.config.device_name,
27+
hostname=self.hostname,
28+
device_name=self.ib_device,
7329
)
30+
self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
7431

7532
def register(self, ptr, length):
76-
self.engine.register_memory(ptr, length)
33+
ret_value = self.engine.register_memory(ptr, length)
34+
if ret_value != 0:
35+
logger.error("Mooncake memory registration failed.")
36+
raise RuntimeError("Mooncake memory registration failed.")
7737

7838
def deregister(self, ptr):
79-
self.engine.unregister_memory(ptr)
39+
ret_value = self.engine.unregister_memory(ptr)
40+
if ret_value != 0:
41+
logger.error("Mooncake memory deregistration failed.")
42+
raise RuntimeError("Mooncake memory deregistration failed.")
8043

8144
def initialize(
8245
self,
83-
local_hostname: str,
84-
metadata_server: str,
85-
protocol: str,
86-
device_name: str,
46+
hostname: str,
47+
device_name: Optional[str],
8748
) -> None:
8849
"""Initialize the mooncake instance."""
89-
self.engine.initialize(local_hostname, metadata_server, protocol, device_name)
50+
ret_value = self.engine.initialize(
51+
hostname,
52+
"P2PHANDSHAKE",
53+
"rdma",
54+
device_name if device_name is not None else "",
55+
)
56+
if ret_value != 0:
57+
logger.error("Mooncake Transfer Engine initialization failed.")
58+
raise RuntimeError("Mooncake Transfer Engine initialization failed.")
9059

9160
def transfer_sync(
9261
self, session_id: str, buffer: int, peer_buffer_address: int, length: int
@@ -97,12 +66,12 @@ def transfer_sync(
9766
session_id, buffer, peer_buffer_address, length
9867
)
9968
if ret < 0:
100-
logger.error("Transfer Return Error")
101-
raise Exception("Transfer Return Error")
69+
logger.error("Mooncake Transfer Engine Return Error.")
70+
raise RuntimeError("Mooncake Transfer Engine Return Error.")
10271
return ret
10372

10473
def get_localhost(self):
105-
return self.config.local_hostname
74+
return self.hostname
10675

10776
def get_session_id(self):
10877
return self.session_id

python/sglang/srt/disaggregation/prefill.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def _init_kv_manager(self) -> BaseKVManager:
103103
kv_args.aux_item_lens = [
104104
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
105105
]
106-
kv_args.ib_device = "mock-ib-device"
106+
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
107107
kv_args.gpu_id = self.scheduler.gpu_id
108108
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
109109
kv_manager = kv_manager_class(

python/sglang/srt/server_args.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ class ServerArgs:
196196
disaggregation_mode: str = "null"
197197
disaggregation_bootstrap_port: int = 8998
198198
disaggregation_transfer_backend: str = "mooncake"
199+
disaggregation_ib_device: Optional[str] = None
199200

200201
def __post_init__(self):
201202
# Expert parallelism
@@ -1193,6 +1194,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
11931194
default=ServerArgs.disaggregation_transfer_backend,
11941195
help="The backend for disaggregation transfer. Default is mooncake.",
11951196
)
1197+
parser.add_argument(
1198+
"--disaggregation-ib-device",
1199+
type=str,
1200+
default=ServerArgs.disaggregation_ib_device,
1201+
help="The ib device for disaggregation transfer. Default is None, it will be detected automatically if using the mooncake backend.",
1202+
)
11961203

11971204
@classmethod
11981205
def from_cli_args(cls, args: argparse.Namespace):

0 commit comments

Comments
 (0)