Skip to content

Commit 33b16ad

Browse files
authored
Distinguish bootstrap key only in decode server (#5422)
1 parent ffde65a commit 33b16ad

File tree

3 files changed

+18
-29
lines changed

3 files changed

+18
-29
lines changed

python/sglang/srt/disaggregation/decode.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,7 @@
2828
import torch
2929
from torch.distributed import ProcessGroup
3030

31-
from sglang.srt.disaggregation.base import (
32-
BaseKVManager,
33-
BaseKVReceiver,
34-
BaseKVSender,
35-
KVArgs,
36-
KVPoll,
37-
)
31+
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
3832
from sglang.srt.disaggregation.utils import (
3933
DisaggregationMode,
4034
KVClassType,

python/sglang/srt/disaggregation/mooncake/conn.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def _register_to_bootstrap(self):
329329
"role": "Prefill",
330330
"rank_ip": get_local_ip_by_remote(),
331331
"rank_port": self.rank_port,
332-
"bootstrap_key": f"{bootstrap_server_url}_{self.kv_args.engine_rank}",
332+
"engine_rank": self.kv_args.engine_rank,
333333
}
334334

335335
try:
@@ -400,28 +400,29 @@ def __init__(
400400
self.session_id = self.kv_mgr.get_session_id()
401401
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
402402

403-
self.bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
403+
# NOTE: key distinguished by bootstrap_addr and engine_rank
404+
bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
404405

405-
if self.bootstrap_key not in self.kv_mgr.connection_pool:
406+
if bootstrap_key not in self.kv_mgr.connection_pool:
406407
self.bootstrap_info = self._get_bootstrap_info_from_server(
407-
self.bootstrap_key
408+
self.kv_mgr.kv_args.engine_rank
408409
)
409410
if self.bootstrap_info is None:
410411
logger.error(
411412
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
412413
)
413414
else:
414-
self.kv_mgr.connection_pool[self.bootstrap_key] = self.bootstrap_info
415+
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
415416
else:
416-
self.bootstrap_info = self.kv_mgr.connection_pool[self.bootstrap_key]
417+
self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
417418

418419
assert self.bootstrap_info is not None
419420
self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
420421

421-
def _get_bootstrap_info_from_server(self, bootstrap_key: str):
422+
def _get_bootstrap_info_from_server(self, engine_rank):
422423
"""Fetch the bootstrap info from the bootstrap server."""
423424
try:
424-
url = f"http://{self.bootstrap_addr}/route?bootstrap_key={bootstrap_key}"
425+
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
425426
response = requests.get(url)
426427
if response.status_code == 200:
427428
bootstrap_info = response.json()
@@ -556,28 +557,28 @@ async def _handle_route_put(self, request: web.Request):
556557
role = data["role"]
557558
rank_ip = data["rank_ip"]
558559
rank_port = int(data["rank_port"])
559-
bootstrap_key = data["bootstrap_key"]
560+
engine_rank = int(data["engine_rank"])
560561

561562
# Add lock to make sure thread-safe
562563
if role == "Prefill":
563-
self.prefill_port_table[bootstrap_key] = {
564+
self.prefill_port_table[engine_rank] = {
564565
"rank_ip": rank_ip,
565566
"rank_port": rank_port,
566567
}
567568
logger.debug(
568-
f"Registered Prefill bootstrap_key: {bootstrap_key} with rank_ip: {rank_ip} and rank_port: {rank_port}"
569+
f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
569570
)
570571

571572
return web.Response(text="OK", status=200)
572573

573574
async def _handle_route_get(self, request: web.Request):
574-
bootstrap_key = request.query.get("bootstrap_key")
575-
if not bootstrap_key:
576-
return web.Response(text="Missing bootstrap_key", status=400)
575+
engine_rank = request.query.get("engine_rank")
576+
if not engine_rank:
577+
return web.Response(text="Missing rank", status=400)
577578

578579
# Find corresponding prefill info
579580
async with self.lock:
580-
bootstrap_info = self.prefill_port_table.get(bootstrap_key)
581+
bootstrap_info = self.prefill_port_table.get(int(engine_rank))
581582

582583
if bootstrap_info is not None:
583584
return web.json_response(bootstrap_info, status=200)

python/sglang/srt/disaggregation/prefill.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,7 @@
2424

2525
import torch
2626

27-
from sglang.srt.disaggregation.base import (
28-
BaseKVManager,
29-
BaseKVReceiver,
30-
BaseKVSender,
31-
KVArgs,
32-
KVPoll,
33-
)
27+
from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
3428
from sglang.srt.disaggregation.utils import (
3529
DisaggregationMode,
3630
KVClassType,

0 commit comments

Comments
 (0)