Commit 11383ce

[PP] Add pipeline parallelism (#5724)
1 parent e97e57e commit 11383ce

25 files changed: +1149 −307 lines

python/sglang/bench_one_batch.py

Lines changed: 2 additions & 0 deletions
@@ -154,6 +154,8 @@ def load_model(server_args, port_args, tp_rank):
         gpu_id=tp_rank,
         tp_rank=tp_rank,
         tp_size=server_args.tp_size,
+        pp_rank=0,
+        pp_size=1,
         nccl_port=port_args.nccl_port,
         server_args=server_args,
     )

python/sglang/srt/entrypoints/engine.py

Lines changed: 36 additions & 19 deletions
@@ -126,7 +126,6 @@ def __init__(self, **kwargs):
             server_args=server_args,
             port_args=port_args,
         )
-
         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
         self.scheduler_info = scheduler_info
@@ -301,7 +300,6 @@ def get_server_info(self):
         internal_states = loop.run_until_complete(
             self.tokenizer_manager.get_internal_state()
         )
-
         return {
             **dataclasses.asdict(self.tokenizer_manager.server_args),
             **self.scheduler_info,
@@ -520,25 +518,44 @@ def _launch_subprocesses(
         )

         scheduler_pipe_readers = []
-        tp_size_per_node = server_args.tp_size // server_args.nnodes
+
+        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
+        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
         tp_rank_range = range(
-            tp_size_per_node * server_args.node_rank,
-            tp_size_per_node * (server_args.node_rank + 1),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
         )
-        for tp_rank in tp_rank_range:
-            reader, writer = mp.Pipe(duplex=False)
-            gpu_id = (
-                server_args.base_gpu_id
-                + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
-            )
-            proc = mp.Process(
-                target=run_scheduler_process,
-                args=(server_args, port_args, gpu_id, tp_rank, None, writer),
-            )
-            with memory_saver_adapter.configure_subprocess():
-                proc.start()
-            scheduler_procs.append(proc)
-            scheduler_pipe_readers.append(reader)
+
+        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
+        pp_rank_range = range(
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
+        )
+
+        for pp_rank in pp_rank_range:
+            for tp_rank in tp_rank_range:
+                reader, writer = mp.Pipe(duplex=False)
+                gpu_id = (
+                    server_args.base_gpu_id
+                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
+                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
+                )
+                proc = mp.Process(
+                    target=run_scheduler_process,
+                    args=(
+                        server_args,
+                        port_args,
+                        gpu_id,
+                        tp_rank,
+                        pp_rank,
+                        None,
+                        writer,
+                    ),
+                )
+                with memory_saver_adapter.configure_subprocess():
+                    proc.start()
+                scheduler_procs.append(proc)
+                scheduler_pipe_readers.append(reader)
     else:
         # Launch the data parallel controller
         reader, writer = mp.Pipe(duplex=False)
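For intuition, the rank bookkeeping above can be reproduced in isolation. The sketch below is illustrative only (the scheduler_ranks helper and the SimpleNamespace stand-in for server_args are not part of the commit); it shows which (pp_rank, tp_rank) schedulers each node launches:

from types import SimpleNamespace

def scheduler_ranks(server_args):
    # Mirrors the range computation in _launch_subprocesses (sketch only).
    nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
    tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
    tp_rank_range = range(
        tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
        tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
    )
    pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
    pp_rank_range = range(
        pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
        pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
    )
    return [(pp, tp) for pp in pp_rank_range for tp in tp_rank_range]

# Example: tp_size=8, pp_size=2 on 2 nodes -> each node owns one full pipeline
# stage and all 8 TP ranks of that stage.
for node_rank in range(2):
    args = SimpleNamespace(tp_size=8, pp_size=2, nnodes=2, node_rank=node_rank)
    print(node_rank, scheduler_ranks(args))
# node 0 -> [(0, 0), (0, 1), ..., (0, 7)]
# node 1 -> [(1, 0), (1, 1), ..., (1, 7)]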

python/sglang/srt/layers/dp_attention.py

Lines changed: 5 additions & 2 deletions
@@ -43,6 +43,7 @@ def initialize_dp_attention(
     tp_rank: int,
     tp_size: int,
     dp_size: int,
+    pp_size: int,
 ):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE

@@ -53,17 +54,19 @@ def initialize_dp_attention(
     )

     if enable_dp_attention:
+        local_rank = tp_rank % (tp_size // dp_size)
         _DP_SIZE = dp_size
     else:
+        local_rank = tp_rank
         _DP_SIZE = 1

     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
         [
             list(range(head, head + _ATTN_TP_SIZE))
-            for head in range(0, tp_size, _ATTN_TP_SIZE)
+            for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
         ],
-        tp_group.local_rank,
+        local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,
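The new local_rank and the pp_size * tp_size range change how the attention-TP groups cover the global ranks. A rough sketch, under the assumption that _ATTN_TP_SIZE equals tp_size // dp_size when DP attention is enabled (the helper below is illustrative, not code from this commit):

def attn_tp_groups(tp_size, dp_size, pp_size, enable_dp_attention):
    # Assumption for illustration: the attention-TP group size is
    # tp_size // dp_size with DP attention enabled, else tp_size.
    attn_tp_size = tp_size // dp_size if enable_dp_attention else tp_size
    return [
        list(range(head, head + attn_tp_size))
        for head in range(0, pp_size * tp_size, attn_tp_size)
    ]

# tp_size=4, dp_size=2, pp_size=2: eight global ranks split into groups of 2,
# now spanning both pipeline stages instead of only the first tp_size ranks.
print(attn_tp_groups(4, 2, 2, True))
# [[0, 1], [2, 3], [4, 5], [6, 7]]
# A rank with tp_rank=3 indexes into its group with
# local_rank = 3 % (4 // 2) = 1.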

python/sglang/srt/layers/utils.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+import logging
+import re
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def get_layer_id(weight_name):
+    # example weight name: model.layers.10.self_attn.qkv_proj.weight
+    match = re.search(r"layers\.(\d+)\.", weight_name)
+    if match:
+        return int(match.group(1))
+    return None
+
+
+class PPMissingLayer(torch.nn.Identity):
+    # Adapted from
+    # https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
+    """
+    A placeholder layer for missing layers in a pipeline parallel model.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.return_tuple = kwargs.get("return_tuple", False)
+
+    def forward(self, *args, **kwargs):
+        """
+        Return the first arg from args or the first value from kwargs.
+
+        Wraps the input in a tuple if `self.return_tuple` is True.
+        """
+        input = args[0] if args else next(iter(kwargs.values()))
+        return (input,) if self.return_tuple else input
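A short usage sketch for the new helpers; the two-stage split below is a hypothetical example for illustration, not taken from any model file in this commit:

import torch
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id

# get_layer_id extracts the decoder-layer index from a checkpoint weight name.
assert get_layer_id("model.layers.10.self_attn.qkv_proj.weight") == 10
assert get_layer_id("model.embed_tokens.weight") is None

# Hypothetical two-stage split of a 4-layer model: the second stage replaces
# layers 0-1 with PPMissingLayer placeholders, so layer indexing stays aligned
# while the missing layers cost nothing to run.
stage2_layers = torch.nn.ModuleList(
    [PPMissingLayer() for _ in range(2)] + [torch.nn.Linear(8, 8) for _ in range(2)]
)
out = torch.randn(1, 8)
for layer in stage2_layers:
    out = layer(out)  # placeholders simply return their input unchanged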

python/sglang/srt/managers/data_parallel_controller.py

Lines changed: 52 additions & 34 deletions
@@ -181,44 +181,62 @@ def launch_tensor_parallel_group(
             enable=server_args.enable_memory_saver
         )

-        # Launch tensor parallel scheduler processes
         scheduler_pipe_readers = []
-        tp_size_per_node = server_args.tp_size // server_args.nnodes
+
+        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
+        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
         tp_rank_range = range(
-            tp_size_per_node * server_args.node_rank,
-            tp_size_per_node * (server_args.node_rank + 1),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
+        )
+
+        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
+        pp_rank_range = range(
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
         )
-        for tp_rank in tp_rank_range:
-            rank_port_args = port_args
-
-            if server_args.enable_dp_attention:
-                # dp attention has different sharding logic
-                _, _, dp_rank = compute_dp_attention_world_info(
-                    server_args.enable_dp_attention,
-                    tp_rank,
-                    server_args.tp_size,
-                    server_args.dp_size,
-                )
-                # compute zmq ports for this dp rank
-                rank_port_args = PortArgs.init_new(server_args, dp_rank)
-                # Data parallelism resues the tensor parallelism group,
-                # so all dp ranks should use the same nccl port.
-                rank_port_args.nccl_port = port_args.nccl_port
-
-            reader, writer = mp.Pipe(duplex=False)
-            gpu_id = (
-                server_args.base_gpu_id
-                + base_gpu_id
-                + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
-            )
-            proc = mp.Process(
-                target=run_scheduler_process,
-                args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
-            )
-            with memory_saver_adapter.configure_subprocess():
-                proc.start()
-            self.scheduler_procs.append(proc)
-            scheduler_pipe_readers.append(reader)
+
+        for pp_rank in pp_rank_range:
+            for tp_rank in tp_rank_range:
+                rank_port_args = port_args
+
+                if server_args.enable_dp_attention:
+                    # dp attention has different sharding logic
+                    _, _, dp_rank = compute_dp_attention_world_info(
+                        server_args.enable_dp_attention,
+                        tp_rank,
+                        server_args.tp_size,
+                        server_args.dp_size,
+                    )
+                    # compute zmq ports for this dp rank
+                    rank_port_args = PortArgs.init_new(server_args, dp_rank)
+                    # Data parallelism reuses the tensor parallelism group,
+                    # so all dp ranks should use the same nccl port.
+                    rank_port_args.nccl_port = port_args.nccl_port
+
+                reader, writer = mp.Pipe(duplex=False)
+                gpu_id = (
+                    server_args.base_gpu_id
+                    + base_gpu_id
+                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
+                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
+                )
+                proc = mp.Process(
+                    target=run_scheduler_process,
+                    args=(
+                        server_args,
+                        rank_port_args,
+                        gpu_id,
+                        tp_rank,
+                        pp_rank,
+                        dp_rank,
+                        writer,
+                    ),
+                )
+                with memory_saver_adapter.configure_subprocess():
+                    proc.start()
+                self.scheduler_procs.append(proc)
+                scheduler_pipe_readers.append(reader)

         # Wait for model to finish loading
         scheduler_info = []
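The gpu_id expression above packs the schedulers of all local pipeline stages onto consecutive GPUs. A small illustrative sketch (local_gpu_id is a made-up helper that folds server_args.base_gpu_id and the controller's per-group base_gpu_id offset into one argument):

def local_gpu_id(base, pp_rank, tp_rank, pp_size_per_node, tp_size_per_node, gpu_id_step=1):
    # Mirrors the gpu_id expression above (sketch only).
    return (
        base
        + (pp_rank % pp_size_per_node) * tp_size_per_node
        + (tp_rank % tp_size_per_node) * gpu_id_step
    )

# Single 8-GPU node running pp_size=2 x tp_size=4: stage 0 lands on GPUs 0-3,
# stage 1 on GPUs 4-7.
for pp_rank in range(2):
    for tp_rank in range(4):
        print(pp_rank, tp_rank, local_gpu_id(0, pp_rank, tp_rank, 2, 4))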

python/sglang/srt/managers/schedule_batch.py

Lines changed: 25 additions & 15 deletions
@@ -66,23 +66,24 @@
 # Put some global args for easy access
 global_server_args_dict = {
     "attention_backend": ServerArgs.attention_backend,
-    "sampling_backend": ServerArgs.sampling_backend,
-    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
-    "torchao_config": ServerArgs.torchao_config,
-    "enable_nan_detection": ServerArgs.enable_nan_detection,
-    "enable_dp_attention": ServerArgs.enable_dp_attention,
-    "enable_ep_moe": ServerArgs.enable_ep_moe,
-    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
+    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "deepep_mode": ServerArgs.deepep_mode,
     "device": ServerArgs.device,
-    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
-    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
+    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
+    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
+    "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_ep_moe": ServerArgs.enable_ep_moe,
+    "enable_nan_detection": ServerArgs.enable_nan_detection,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
+    "max_micro_batch_size": ServerArgs.max_micro_batch_size,
     "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
-    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
-    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
+    "sampling_backend": ServerArgs.sampling_backend,
+    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
+    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
+    "torchao_config": ServerArgs.torchao_config,
+    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
 }

 logger = logging.getLogger(__name__)
@@ -728,6 +729,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Events
     launch_done: Optional[threading.Event] = None

+    # For chunked prefill in PP
+    chunked_req: Optional[Req] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None
@@ -761,7 +765,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # For extend and mixed chunked prefill
     prefix_lens: List[int] = None
     extend_lens: List[int] = None
-    extend_num_tokens: int = None
+    extend_num_tokens: Optional[int] = None
     decoding_reqs: List[Req] = None
     extend_logprob_start_lens: List[int] = None
     # It comes empty list if logprob is not required.
@@ -803,6 +807,7 @@ def init_new(
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
+        chunked_req: Optional[Req] = None,
     ):
         return_logprob = any(req.return_logprob for req in reqs)

@@ -820,6 +825,7 @@ def init_new(
             spec_algorithm=spec_algorithm,
             enable_custom_logit_processor=enable_custom_logit_processor,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
+            chunked_req=chunked_req,
         )

     def batch_size(self):
@@ -1236,7 +1242,7 @@ def check_decode_mem(self, buf_multiplier=1):

     def retract_decode(self, server_args: ServerArgs):
         """Retract the decoding requests when there is not enough memory."""
-        sorted_indices = [i for i in range(len(self.reqs))]
+        sorted_indices = list(range(len(self.reqs)))

         # TODO(lsyin): improve retraction policy for radix cache
         # For spec decoding, filter_batch API can only filter
@@ -1413,15 +1419,19 @@ def prepare_for_decode(self):

     def filter_batch(
         self,
-        chunked_req_to_exclude: Optional[Req] = None,
+        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
+            if isinstance(chunked_req_to_exclude, Req):
+                chunked_req_to_exclude = [chunked_req_to_exclude]
+            elif chunked_req_to_exclude is None:
+                chunked_req_to_exclude = []
             keep_indices = [
                 i
                 for i in range(len(self.reqs))
                 if not self.reqs[i].finished()
-                and self.reqs[i] is not chunked_req_to_exclude
+                and not self.reqs[i] in chunked_req_to_exclude
             ]

         if keep_indices is None or len(keep_indices) == 0:
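The widened filter_batch signature now accepts either a single chunked request or a list of them (presumably to cover the multiple in-flight microbatches under PP). A self-contained sketch of the normalization, using a stand-in Req class purely for illustration:

from typing import List, Optional, Union

class Req:  # illustrative stand-in for sglang.srt.managers.schedule_batch.Req
    def __init__(self, rid, done=False):
        self.rid, self._done = rid, done
    def finished(self):
        return self._done

def keep_indices_for(reqs: List[Req], chunked_req_to_exclude: Optional[Union[Req, List[Req]]]):
    # Same normalization as the new filter_batch logic: accept a single Req,
    # a list of Reqs, or None, then filter by membership instead of identity.
    if isinstance(chunked_req_to_exclude, Req):
        chunked_req_to_exclude = [chunked_req_to_exclude]
    elif chunked_req_to_exclude is None:
        chunked_req_to_exclude = []
    return [
        i
        for i, req in enumerate(reqs)
        if not req.finished() and req not in chunked_req_to_exclude
    ]

a, b, c = Req("a"), Req("b", done=True), Req("c")
print(keep_indices_for([a, b, c], c))       # [0]    -> b finished, c excluded
print(keep_indices_for([a, b, c], [a, c]))  # []     -> everything excluded or finished
print(keep_indices_for([a, b, c], None))    # [0, 2]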
