
Commit b96cd09

add pp

1 parent: c998d04
21 files changed: +1037 -299 lines

python/sglang/srt/entrypoints/engine.py

Lines changed: 36 additions & 19 deletions
@@ -120,7 +120,6 @@ def __init__(self, **kwargs):
             server_args=server_args,
             port_args=port_args,
         )
-
         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
         self.scheduler_info = scheduler_info
@@ -295,7 +294,6 @@ def get_server_info(self):
         internal_states = loop.run_until_complete(
             self.tokenizer_manager.get_internal_state()
         )
-
         return {
             **dataclasses.asdict(self.tokenizer_manager.server_args),
             **self.scheduler_info,
@@ -508,25 +506,44 @@ def _launch_subprocesses(
         )

         scheduler_pipe_readers = []
-        tp_size_per_node = server_args.tp_size // server_args.nnodes
+
+        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
+        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
         tp_rank_range = range(
-            tp_size_per_node * server_args.node_rank,
-            tp_size_per_node * (server_args.node_rank + 1),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
         )
-        for tp_rank in tp_rank_range:
-            reader, writer = mp.Pipe(duplex=False)
-            gpu_id = (
-                server_args.base_gpu_id
-                + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
-            )
-            proc = mp.Process(
-                target=run_scheduler_process,
-                args=(server_args, port_args, gpu_id, tp_rank, None, writer),
-            )
-            with memory_saver_adapter.configure_subprocess():
-                proc.start()
-            scheduler_procs.append(proc)
-            scheduler_pipe_readers.append(reader)
+
+        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
+        pp_rank_range = range(
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
+        )
+
+        for pp_rank in pp_rank_range:
+            for tp_rank in tp_rank_range:
+                reader, writer = mp.Pipe(duplex=False)
+                gpu_id = (
+                    server_args.base_gpu_id
+                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
+                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
+                )
+                proc = mp.Process(
+                    target=run_scheduler_process,
+                    args=(
+                        server_args,
+                        port_args,
+                        gpu_id,
+                        tp_rank,
+                        pp_rank,
+                        None,
+                        writer,
+                    ),
+                )
+                with memory_saver_adapter.configure_subprocess():
+                    proc.start()
+                scheduler_procs.append(proc)
+                scheduler_pipe_readers.append(reader)
     else:
         # Launch the data parallel controller
         reader, writer = mp.Pipe(duplex=False)

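A standalone sketch of the new rank arithmetic above, not part of the commit itself. It assumes a hypothetical 2-node deployment with tp_size=4, pp_size=2, and nnodes=2, so that each node hosts one full TP group for one pipeline stage; the Args class is a stand-in for ServerArgs:

# Hypothetical config: 2 nodes, tp_size=4, pp_size=2.
class Args:
    nnodes = 2
    node_rank = 1  # try 0 or 1
    tp_size = 4
    pp_size = 2

server_args = Args()

nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)  # 1
tp_size_per_node = server_args.tp_size // nnodes_per_tp_group            # 4
tp_rank_range = range(
    tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
    tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
)  # range(0, 4) on either node

pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)     # 1
pp_rank_range = range(
    pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
    pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
)  # range(0, 1) on node 0, range(1, 2) on node 1

# Node 0 launches schedulers for (pp_rank=0, tp_rank=0..3);
# node 1 launches schedulers for (pp_rank=1, tp_rank=0..3).
print(list(pp_rank_range), list(tp_rank_range))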
python/sglang/srt/layers/dp_attention.py

Lines changed: 5 additions & 2 deletions
@@ -43,6 +43,7 @@ def initialize_dp_attention(
     tp_rank: int,
     tp_size: int,
     dp_size: int,
+    pp_size: int,
 ):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE

@@ -53,17 +54,19 @@ def initialize_dp_attention(
     )

     if enable_dp_attention:
+        local_rank = tp_rank % (tp_size // dp_size)
         _DP_SIZE = dp_size
     else:
+        local_rank = tp_rank
         _DP_SIZE = 1

     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
         [
             list(range(head, head + _ATTN_TP_SIZE))
-            for head in range(0, tp_size, _ATTN_TP_SIZE)
+            for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
         ],
-        tp_group.local_rank,
+        local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,

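An illustration of the group layout this change produces, not from the commit. It assumes hypothetical values tp_size=4, dp_size=2, pp_size=2 and assumes _ATTN_TP_SIZE equals tp_size // dp_size when dp attention is enabled, as set elsewhere in this file:

# Hypothetical numbers, for illustration only.
tp_size, dp_size, pp_size = 4, 2, 2
attn_tp_size = tp_size // dp_size  # assumed value of _ATTN_TP_SIZE

# The changed comprehension now covers all pp_size * tp_size ranks:
groups = [
    list(range(head, head + attn_tp_size))
    for head in range(0, pp_size * tp_size, attn_tp_size)
]
print(groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]

# Each rank now indexes into its subgroup with a rank derived from
# tp_rank instead of tp_group.local_rank:
tp_rank = 3
local_rank = tp_rank % (tp_size // dp_size)
print(local_rank)  # 1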
python/sglang/srt/layers/utils.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+import logging
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+class PPMissingLayer(torch.nn.Identity):
+    # Adapted from
+    # https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
+    """
+    A placeholder layer for missing layers in a pipeline parallel model.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.return_tuple = kwargs.get("return_tuple", False)
+
+    def forward(self, *args, **kwargs):
+        """
+        Return the first arg from args or the first value from kwargs.
+
+        Wraps the input in a tuple if `self.return_tuple` is True.
+        """
+        input = args[0] if args else next(iter(kwargs.values()))
+        return (input,) if self.return_tuple else input

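A usage sketch for PPMissingLayer, not part of the commit: the model shapes, layer split, and variable names below are hypothetical. Layers a pipeline stage does not own are replaced with the placeholder, so the forward pass over the full layer list degrades to an identity for the missing ones:

import torch

from sglang.srt.layers.utils import PPMissingLayer

# Hypothetical 4-layer model split evenly across 2 pipeline stages.
num_layers, pp_rank, pp_size = 4, 0, 2
start = pp_rank * (num_layers // pp_size)
end = start + num_layers // pp_size

layers = torch.nn.ModuleList(
    [
        torch.nn.Linear(8, 8) if start <= i < end else PPMissingLayer()
        for i in range(num_layers)
    ]
)

x = torch.randn(2, 8)
for layer in layers:
    # Placeholder layers pass the hidden states through unchanged.
    x = layer(x)
print(x.shape)  # torch.Size([2, 8])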
python/sglang/srt/managers/data_parallel_controller.py

Lines changed: 52 additions & 34 deletions
@@ -181,44 +181,62 @@ def launch_tensor_parallel_group(
             enable=server_args.enable_memory_saver
         )

-        # Launch tensor parallel scheduler processes
         scheduler_pipe_readers = []
-        tp_size_per_node = server_args.tp_size // server_args.nnodes
+
+        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
+        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
         tp_rank_range = range(
-            tp_size_per_node * server_args.node_rank,
-            tp_size_per_node * (server_args.node_rank + 1),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
+        )
+
+        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
+        pp_rank_range = range(
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
         )
-        for tp_rank in tp_rank_range:
-            rank_port_args = port_args
-
-            if server_args.enable_dp_attention:
-                # dp attention has different sharding logic
-                _, _, dp_rank = compute_dp_attention_world_info(
-                    server_args.enable_dp_attention,
-                    tp_rank,
-                    server_args.tp_size,
-                    server_args.dp_size,
+
+        for pp_rank in pp_rank_range:
+            for tp_rank in tp_rank_range:
+                rank_port_args = port_args
+
+                if server_args.enable_dp_attention:
+                    # dp attention has different sharding logic
+                    _, _, dp_rank = compute_dp_attention_world_info(
+                        server_args.enable_dp_attention,
+                        tp_rank,
+                        server_args.tp_size,
+                        server_args.dp_size,
+                    )
+                    # compute zmq ports for this dp rank
+                    rank_port_args = PortArgs.init_new(server_args, dp_rank)
+                    # Data parallelism resues the tensor parallelism group,
+                    # so all dp ranks should use the same nccl port.
+                    rank_port_args.nccl_port = port_args.nccl_port
+
+                reader, writer = mp.Pipe(duplex=False)
+                gpu_id = (
+                    server_args.base_gpu_id
+                    + base_gpu_id
+                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
+                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
                )
-                # compute zmq ports for this dp rank
-                rank_port_args = PortArgs.init_new(server_args, dp_rank)
-                # Data parallelism resues the tensor parallelism group,
-                # so all dp ranks should use the same nccl port.
-                rank_port_args.nccl_port = port_args.nccl_port
-
-            reader, writer = mp.Pipe(duplex=False)
-            gpu_id = (
-                server_args.base_gpu_id
-                + base_gpu_id
-                + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
-            )
-            proc = mp.Process(
-                target=run_scheduler_process,
-                args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
-            )
-            with memory_saver_adapter.configure_subprocess():
-                proc.start()
-            self.scheduler_procs.append(proc)
-            scheduler_pipe_readers.append(reader)
+                proc = mp.Process(
+                    target=run_scheduler_process,
+                    args=(
+                        server_args,
+                        rank_port_args,
+                        gpu_id,
+                        tp_rank,
+                        pp_rank,
+                        dp_rank,
+                        writer,
+                    ),
+                )
+                with memory_saver_adapter.configure_subprocess():
+                    proc.start()
+                self.scheduler_procs.append(proc)
+                scheduler_pipe_readers.append(reader)

         # Wait for model to finish loading
         scheduler_info = []

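A small standalone sketch of how the gpu_id expression above lays out devices when one node hosts more than one pipeline stage. The values are hypothetical, and base_gpu_id_dp stands in for the controller's per-dp-group base_gpu_id offset:

# Hypothetical single-node layout: tp_size=2, pp_size=2, nnodes=1,
# server base_gpu_id=0, dp offset 0, gpu_id_step=1.
tp_size, pp_size, nnodes, node_rank = 2, 2, 1, 0
base_gpu_id_server, base_gpu_id_dp, gpu_id_step = 0, 0, 1

nnodes_per_tp_group = max(nnodes // pp_size, 1)    # 1
tp_size_per_node = tp_size // nnodes_per_tp_group  # 2
pp_size_per_node = max(pp_size // nnodes, 1)       # 2

for pp_rank in range(pp_size_per_node):
    for tp_rank in range(tp_size_per_node):
        gpu_id = (
            base_gpu_id_server
            + base_gpu_id_dp
            + (pp_rank % pp_size_per_node) * tp_size_per_node
            + (tp_rank % tp_size_per_node) * gpu_id_step
        )
        print(f"pp_rank={pp_rank} tp_rank={tp_rank} -> GPU {gpu_id}")
# pp_rank=0 tp_rank=0 -> GPU 0
# pp_rank=0 tp_rank=1 -> GPU 1
# pp_rank=1 tp_rank=0 -> GPU 2
# pp_rank=1 tp_rank=1 -> GPU 3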
python/sglang/srt/managers/schedule_batch.py

Lines changed: 27 additions & 15 deletions
@@ -65,23 +65,26 @@
 # Put some global args for easy access
 global_server_args_dict = {
     "attention_backend": ServerArgs.attention_backend,
-    "sampling_backend": ServerArgs.sampling_backend,
-    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
-    "torchao_config": ServerArgs.torchao_config,
-    "enable_nan_detection": ServerArgs.enable_nan_detection,
-    "enable_dp_attention": ServerArgs.enable_dp_attention,
-    "enable_ep_moe": ServerArgs.enable_ep_moe,
-    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
+    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "deepep_mode": ServerArgs.deepep_mode,
     "device": ServerArgs.device,
-    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
-    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
+    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
+    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
+    "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_ep_moe": ServerArgs.enable_ep_moe,
+    "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
+    "enable_flashmla": ServerArgs.enable_flashmla,
+    "enable_nan_detection": ServerArgs.enable_nan_detection,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
+    "max_micro_batch_size": ServerArgs.max_micro_batch_size,
     "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
-    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
-    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
+    "sampling_backend": ServerArgs.sampling_backend,
+    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
+    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
+    "torchao_config": ServerArgs.torchao_config,
+    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
 }

 logger = logging.getLogger(__name__)
@@ -709,6 +712,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # This is an optimization to reduce the overhead of the prefill check.
     batch_is_full: bool = False

+    # For chunked prefill in PP
+    chunked_req: Optional[Req] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None
@@ -742,7 +748,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # For extend and mixed chunekd prefill
     prefix_lens: List[int] = None
     extend_lens: List[int] = None
-    extend_num_tokens: int = None
+    extend_num_tokens: Optional[int] = None
     decoding_reqs: List[Req] = None
     extend_logprob_start_lens: List[int] = None
     # It comes empty list if logprob is not required.
@@ -784,6 +790,7 @@ def init_new(
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
         enable_custom_logit_processor: bool,
+        chunked_req: Optional[Req] = None,
     ):
         return_logprob = any(req.return_logprob for req in reqs)

@@ -801,6 +808,7 @@ def init_new(
             spec_algorithm=spec_algorithm,
             enable_custom_logit_processor=enable_custom_logit_processor,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
+            chunked_req=chunked_req,
         )

     def batch_size(self):
@@ -1217,7 +1225,7 @@ def check_decode_mem(self, buf_multiplier=1):

     def retract_decode(self, server_args: ServerArgs):
         """Retract the decoding requests when there is not enough memory."""
-        sorted_indices = [i for i in range(len(self.reqs))]
+        sorted_indices = list(range(len(self.reqs)))

         # TODO(lsyin): improve retraction policy for radix cache
         # For spec decoding, filter_batch API can only filter
@@ -1394,15 +1402,19 @@ def prepare_for_decode(self):

     def filter_batch(
         self,
-        chunked_req_to_exclude: Optional[Req] = None,
+        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
+            if isinstance(chunked_req_to_exclude, Req):
+                chunked_req_to_exclude = [chunked_req_to_exclude]
+            elif chunked_req_to_exclude is None:
+                chunked_req_to_exclude = []
             keep_indices = [
                 i
                 for i in range(len(self.reqs))
                 if not self.reqs[i].finished()
-                and self.reqs[i] is not chunked_req_to_exclude
+                and not self.reqs[i] in chunked_req_to_exclude
             ]

         if keep_indices is None or len(keep_indices) == 0:

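A minimal sketch of the normalization the new filter_batch code performs so that either a single chunked request or a list of them can be excluded. The Req class and the keep_indices_for helper below are stand-ins for illustration only, not sglang's real classes:

from typing import List, Optional, Union


class Req:  # stand-in for sglang's Req
    def __init__(self, name: str, done: bool = False):
        self.name, self.done = name, done

    def finished(self) -> bool:
        return self.done


def keep_indices_for(
    reqs: List[Req],
    chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
) -> List[int]:
    # Normalize the argument to a list, mirroring the new filter_batch logic.
    if isinstance(chunked_req_to_exclude, Req):
        chunked_req_to_exclude = [chunked_req_to_exclude]
    elif chunked_req_to_exclude is None:
        chunked_req_to_exclude = []
    return [
        i
        for i, req in enumerate(reqs)
        if not req.finished() and req not in chunked_req_to_exclude
    ]


reqs = [Req("a"), Req("b", done=True), Req("c")]
print(keep_indices_for(reqs, reqs[2]))             # [0]
print(keep_indices_for(reqs, [reqs[0], reqs[2]]))  # []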