Skip to content

Commit e1c10bc

Browse files
ympcMarkUNIDY2002
andcommitted
Achieve fault tolerance at the DP level
Co-authored-by: UNIDY2002 <unidy2002@outlook.com>
1 parent 351cfd6 commit e1c10bc

5 files changed

Lines changed: 56 additions & 4 deletions

File tree

python/sglang/srt/managers/data_parallel_controller.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sglang.srt.environ import envs
3030
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
3131
from sglang.srt.managers.io_struct import (
32+
ActiveRanksOutput,
3233
BlockReqInput,
3334
TokenizedEmbeddingReqInput,
3435
TokenizedGenerateReqInput,
@@ -158,6 +159,7 @@ def __init__(
158159
# Launch data parallel workers
159160
self.scheduler_procs = []
160161
self.workers: List[zmq.Socket] = [None] * server_args.dp_size
162+
self.status: List[int] = [1] * server_args.dp_size
161163

162164
if server_args.enable_dp_attention:
163165
self.launch_dp_attention_schedulers(server_args, port_args)
@@ -179,8 +181,9 @@ def __init__(
179181
start_cpu_monitor_thread("data_parallel_controller")
180182

181183
def send_to_all_workers(self, obj):
182-
for worker in self.workers:
183-
worker.send_pyobj(obj)
184+
for i, worker in enumerate(self.workers):
185+
if self.status[i] == 1:
186+
worker.send_pyobj(obj)
184187

185188
def send_control_message(self, obj):
186189
# Send control messages to first worker of tp group
@@ -190,6 +193,9 @@ def send_control_message(self, obj):
190193
def handle_load_update_req(self, obj):
191194
self.dp_budget.update_budget(obj)
192195

196+
def update_active_ranks(self, ranks: ActiveRanksOutput):
197+
self.status = ranks.status
198+
193199
def dispatching_with_trace(self, req: Req):
194200
if self.server_args.enable_trace:
195201
trace_set_proc_propagate_context(req.rid, req.trace_context)
@@ -208,6 +214,7 @@ def init_dispatcher(self):
208214
(TokenizedEmbeddingReqInput, self.dispatching_with_trace),
209215
(BlockReqInput, self.send_to_all_workers),
210216
(WatchLoadUpdateReq, self.handle_load_update_req),
217+
(ActiveRanksOutput, self.update_active_ranks),
211218
]
212219
)
213220
self._request_dispatcher.add_fallback_fn(self.send_control_message)
@@ -479,8 +486,17 @@ def round_robin_scheduler(self, req: Req):
479486
if self.maybe_external_dp_rank_routing(req):
480487
return
481488

482-
self.workers[self.round_robin_counter].send_pyobj(req)
483-
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
489+
while True:
490+
if self.status[self.round_robin_counter] == 1:
491+
logger.info(f"Choose worker {self.round_robin_counter}")
492+
self.workers[self.round_robin_counter].send_pyobj(req)
493+
self.round_robin_counter = (self.round_robin_counter + 1) % len(
494+
self.workers
495+
)
496+
break
497+
self.round_robin_counter = (self.round_robin_counter + 1) % len(
498+
self.workers
499+
)
484500

485501
def follow_bootstrap_room_scheduler(self, req: Req):
486502
if self.maybe_external_dp_rank_routing(req):

python/sglang/srt/managers/io_struct.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,6 +1434,11 @@ def __post_init__(self):
14341434
self.rid = ""
14351435

14361436

1437+
@dataclass
1438+
class ActiveRanksOutput(BaseReq):
1439+
status: List[int]
1440+
1441+
14371442
@dataclass
14381443
class GetInternalStateReq(BaseReq):
14391444
pass

python/sglang/srt/managers/scheduler.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config
6969
from sglang.srt.managers.io_struct import (
7070
AbortReq,
71+
ActiveRanksOutput,
7172
BaseBatchReq,
7273
BaseReq,
7374
BatchTokenizedEmbeddingReqInput,
@@ -2273,6 +2274,19 @@ def run_batch(
22732274
for req in batch.reqs:
22742275
req.time_stats.prefill_end_time_host = current_time
22752276

2277+
if (
2278+
self.server_args.enable_dp_attention
2279+
and self.server_args.elastic_ep_backend == "mooncake"
2280+
):
2281+
# Get the tensors indicating rank activeness
2282+
tp_active_ranks = self.tp_group.active_ranks.detach().cpu().numpy()
2283+
tp_active_ranks_cpu = self.tp_group.active_ranks_cpu.detach().numpy()
2284+
tp_active_ranks &= tp_active_ranks_cpu
2285+
dp_active_ranks = tp_active_ranks.reshape(self.dp_size, -1).prod(axis=1)
2286+
self.send_to_tokenizer.send_output(
2287+
ActiveRanksOutput(status=dp_active_ranks.tolist())
2288+
)
2289+
22762290
return ret
22772291

22782292
def launch_batch_sample_if_needed(

python/sglang/srt/managers/scheduler_dp_attn_mixin.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
import torch
77

88
from sglang.srt.batch_overlap.two_batch_overlap import TboDPAttentionPreparer
9+
from sglang.srt.distributed.parallel_state import get_tp_group
910
from sglang.srt.environ import envs
1011
from sglang.srt.managers.schedule_batch import ScheduleBatch
1112
from sglang.srt.metrics.collector import DPCooperationInfo
13+
from sglang.srt.model_executor.forward_batch_info import ForwardMode
1214
from sglang.srt.utils.common import require_mlp_tp_gather
1315

1416
if TYPE_CHECKING:
@@ -66,6 +68,15 @@ def all_gather(self, device, group: torch.distributed.ProcessGroup):
6668
local_info_tensor,
6769
group=group,
6870
)
71+
if device == "cpu":
72+
tp_active_ranks = get_tp_group().active_ranks_cpu
73+
else:
74+
tp_active_ranks = get_tp_group().active_ranks
75+
global_info_tensor.view(-1, 6)[tp_active_ranks == 0, :] = torch.tensor(
76+
[0, 1, 0, 0, 1, ForwardMode.IDLE.value],
77+
device=global_info_tensor.device,
78+
dtype=global_info_tensor.dtype,
79+
)
6980

7081
tp0_info = global_info_tensor[:, 0, :]
7182
self.tp0_info = tp0_info
@@ -149,6 +160,7 @@ def prepare_mlp_sync_batch_raw(
149160
if len(offload_tags) == 0 and disable_overlap_schedule:
150161
group = tp_group.device_group
151162
device = tp_group.device
163+
torch.distributed.barrier(group=tp_group.cpu_group)
152164
else:
153165
group = tp_group.cpu_group
154166
device = "cpu"

python/sglang/srt/managers/tokenizer_manager.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from sglang.srt.managers.disagg_service import start_disagg_service
4848
from sglang.srt.managers.io_struct import (
4949
AbortReq,
50+
ActiveRanksOutput,
5051
BatchEmbeddingOutput,
5152
BatchMultimodalOutput,
5253
BatchStrOutput,
@@ -465,6 +466,7 @@ def init_request_dispatcher(self):
465466
(FreezeGCReq, lambda x: None),
466467
# For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
467468
(HealthCheckOutput, lambda x: None),
469+
(ActiveRanksOutput, self.update_active_ranks),
468470
]
469471
)
470472
self.init_communicators(self.server_args)
@@ -2104,6 +2106,9 @@ def _handle_abort_req(self, recv_obj: AbortReq):
21042106
state.out_list.append(out)
21052107
state.event.set()
21062108

2109+
def update_active_ranks(self, ranks: ActiveRanksOutput):
2110+
self.send_to_scheduler.send_pyobj(ranks)
2111+
21072112
def _handle_open_session_req_output(self, recv_obj):
21082113
self.session_futures[recv_obj.session_id].set_result(
21092114
recv_obj.session_id if recv_obj.success else None

0 commit comments

Comments
 (0)