
Commit 25c83ff

ch-wan and liusy58 authored
Performing Vocabulary Parallelism for LM Head across Attention TP Groups (#5558)
Co-authored-by: liusy58 <[email protected]>
1 parent 9f2c956 commit 25c83ff

8 files changed: +71 additions, -23 deletions

docs/backend/server_arguments.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -221,3 +221,4 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `triton_attention_num_kv_splits` | Use to adjust the number of KV splits in triton kernels. | `8` |
 | `flashinfer_mla_disable_ragged` | Disable the use of the [ragged prefill](https://github.com/flashinfer-ai/flashinfer/blob/5751fc68f109877f6e0fc54f674cdcdef361af56/docs/tutorials/kv_layout.rst#L26) wrapper for the FlashInfer MLA attention backend. Ragged prefill increases throughput by computing MHA instead of paged MLA when there is no prefix match. Only use it when FlashInfer is being used as the MLA backend. | `False` |
 | `disable_chunked_prefix_cache` | Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is attention backend. | `False` |
+| `enable_dp_lm_head` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | `False` |
```
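For context, a launch command exercising this flag could look something like `python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2 --tp-size 8 --dp-size 8 --enable-dp-attention --enable-dp-lm-head` (the model path and parallel sizes are placeholders); as the `server_args.py` change below enforces, `enable_dp_lm_head` requires `enable_dp_attention`.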

python/sglang/srt/layers/dp_attention.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -252,12 +252,12 @@ def dp_scatter(
     )
 
 
-def tp_reduce_scatter(
+def attn_tp_reduce_scatter(
     output: torch.Tensor,
     input_list: List[torch.Tensor],
 ):
     return get_attention_tp_group().reduce_scatter(output, input_list)
 
 
-def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
+def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
     return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
```
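The rename makes explicit that these collectives run on the attention TP group rather than the full TP group. As a rough illustration of how those groups relate under DP attention (the sizes and helper below are made-up for illustration, not sglang internals), with `tp_size=8` and `dp_size=4` the eight TP ranks split into four attention-TP groups of two ranks each:

```python
# Illustrative sketch only: how attention-TP groups partition the global TP
# ranks under DP attention. Not sglang's actual group-construction code.
def attn_tp_group_ranks(tp_size: int, dp_size: int) -> list[list[int]]:
    attn_tp_size = tp_size // dp_size
    return [
        list(range(g * attn_tp_size, (g + 1) * attn_tp_size))
        for g in range(dp_size)
    ]

print(attn_tp_group_ranks(tp_size=8, dp_size=4))
# [[0, 1], [2, 3], [4, 5], [6, 7]]
```

`attn_tp_all_gather` and `attn_tp_reduce_scatter` therefore only move data within one of these small groups, which is what lets the LM head avoid a gather across DP groups.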

python/sglang/srt/layers/logits_processor.py

Lines changed: 29 additions & 8 deletions
```diff
@@ -23,15 +23,16 @@
 from torch import nn
 
 from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
     dp_gather_replicate,
     dp_scatter,
     get_attention_dp_rank,
     get_attention_dp_size,
+    get_attention_tp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -198,12 +199,20 @@ def __init__(
         super().__init__()
         self.config = config
         self.logit_scale = logit_scale
-        self.do_tensor_parallel_all_gather = (
-            not skip_all_gather and get_tensor_model_parallel_world_size() > 1
-        )
-        self.do_tensor_parallel_all_gather_dp_attn = (
-            self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
-        )
+        self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
+        if self.use_attn_tp_group:
+            self.attn_tp_size = get_attention_tp_size()
+            self.do_tensor_parallel_all_gather = (
+                not skip_all_gather and self.attn_tp_size > 1
+            )
+            self.do_tensor_parallel_all_gather_dp_attn = False
+        else:
+            self.do_tensor_parallel_all_gather = (
+                not skip_all_gather and get_tensor_model_parallel_world_size() > 1
+            )
+            self.do_tensor_parallel_all_gather_dp_attn = (
+                self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
+            )
         self.final_logit_softcapping = getattr(
             self.config, "final_logit_softcapping", None
         )
@@ -442,7 +451,19 @@ def _get_logits(
             logits.mul_(self.logit_scale)
 
         if self.do_tensor_parallel_all_gather:
-            logits = tensor_model_parallel_all_gather(logits)
+            if self.use_attn_tp_group:
+                global_logits = torch.empty(
+                    (self.config.vocab_size, logits.shape[0]),
+                    device=logits.device,
+                    dtype=logits.dtype,
+                )
+                global_logits = global_logits.T
+                attn_tp_all_gather(
+                    list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits
+                )
+                logits = global_logits
+            else:
+                logits = tensor_model_parallel_all_gather(logits)
 
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits, global_logits = (
```
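The notable detail in `_get_logits` is the transposed gather buffer: `global_logits` is allocated as `(vocab_size, num_tokens)` and then viewed as its transpose, so every `tensor_split` chunk along the vocab dimension is backed by its own contiguous slab of the underlying storage, and the attention-TP all-gather can write each rank's vocab shard directly into place. A minimal single-process sketch of that layout trick (the shapes are arbitrary and the `copy_` loop merely stands in for the collective):

```python
import torch

vocab_size, num_tokens, attn_tp_size = 8, 3, 2

# Reference full logits plus the per-rank vocab shards each attention-TP rank
# would have computed locally: shape (num_tokens, vocab_size // attn_tp_size).
full = torch.arange(num_tokens * vocab_size, dtype=torch.float32).reshape(
    num_tokens, vocab_size
)
shards = list(full.tensor_split(attn_tp_size, dim=-1))

# Same buffer trick as the patch: allocate (vocab, tokens), then view it
# transposed so each vocab-dim chunk is a contiguous region of the buffer.
global_logits = torch.empty((vocab_size, num_tokens)).T
chunks = list(global_logits.tensor_split(attn_tp_size, dim=-1))

# Stand-in for attn_tp_all_gather: each rank's shard lands in its chunk view.
for dst, src in zip(chunks, shards):
    dst.copy_(src)

assert torch.equal(global_logits, full)  # gathered logits: (tokens, vocab)
```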

python/sglang/srt/layers/vocab_parallel_embedding.py

Lines changed: 18 additions & 7 deletions
```diff
@@ -13,6 +13,7 @@
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -214,22 +215,29 @@ def __init__(
         self,
         num_embeddings: int,
         embedding_dim: int,
+        *,
         params_dtype: Optional[torch.dtype] = None,
         org_num_embeddings: Optional[int] = None,
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         enable_tp: bool = True,
+        use_attn_tp_group: bool = False,
         use_presharded_weights: bool = False,
     ):
         super().__init__()
         self.quant_config = quant_config
 
         self.enable_tp = enable_tp
         if self.enable_tp:
-            tp_rank = get_tensor_model_parallel_rank()
-            self.tp_size = get_tensor_model_parallel_world_size()
+            if use_attn_tp_group:
+                tp_rank = get_attention_tp_rank()
+                self.tp_size = get_attention_tp_size()
+            else:
+                tp_rank = get_tensor_model_parallel_rank()
+                self.tp_size = get_tensor_model_parallel_world_size()
         else:
+            assert use_attn_tp_group is False
             tp_rank = 0
             self.tp_size = 1
 
@@ -519,22 +527,25 @@ def __init__(
         self,
         num_embeddings: int,
         embedding_dim: int,
+        *,
         bias: bool = False,
         params_dtype: Optional[torch.dtype] = None,
         org_num_embeddings: Optional[int] = None,
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_attn_tp_group: bool = False,
         use_presharded_weights: bool = False,
     ):
         super().__init__(
             num_embeddings,
             embedding_dim,
-            params_dtype,
-            org_num_embeddings,
-            padding_size,
-            quant_config,
-            prefix,
+            params_dtype=params_dtype,
+            org_num_embeddings=org_num_embeddings,
+            padding_size=padding_size,
+            quant_config=quant_config,
+            prefix=prefix,
+            use_attn_tp_group=use_attn_tp_group,
             use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
```
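With `use_attn_tp_group=True`, the LM head's vocabulary is sharded over the smaller attention TP group instead of the full TP group, so each DP attention replica already holds the complete vocabulary across its own attention-TP ranks and no gather across DP groups is needed. A rough sketch of how the shard owned by a rank changes (the helper and sizes below are illustrative assumptions, with the vocab assumed padded to a multiple of the group size):

```python
# Illustrative only: the vocab slice a rank owns depends on which TP group the
# LM head is sharded over.
def vocab_shard(vocab_size: int, rank: int, group_size: int) -> tuple[int, int]:
    per_rank = vocab_size // group_size  # assume vocab padded to a multiple
    return rank * per_rank, (rank + 1) * per_rank

vocab_size = 102400  # placeholder vocab size

# Full-TP sharding with tp_size=8: each rank holds 1/8 of the vocabulary.
print(vocab_shard(vocab_size, rank=3, group_size=8))  # (38400, 51200)

# enable_dp_lm_head with tp=8, dp=4 -> attn_tp_size=2: each rank holds 1/2,
# and the two ranks of one DP replica together cover the whole vocabulary.
print(vocab_shard(vocab_size, rank=1, group_size=2))  # (51200, 102400)
```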

python/sglang/srt/managers/schedule_batch.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -74,6 +74,7 @@
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "enable_deepep_moe": ServerArgs.enable_deepep_moe,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
```

python/sglang/srt/models/deepseek_v2.py

Lines changed: 7 additions & 6 deletions
```diff
@@ -36,13 +36,13 @@
 )
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
+    attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
-    tp_all_gather,
-    tp_reduce_scatter,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -1323,7 +1323,7 @@ def forward_ffn_with_scattered_input(
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
             )
-            tp_all_gather(
+            attn_tp_all_gather(
                 list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
             )
 
@@ -1339,7 +1339,7 @@ def forward_ffn_with_scattered_input(
         if self.input_is_scattered:
             tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
             hidden_states = tensor_list[self.attn_tp_rank]
-            tp_reduce_scatter(hidden_states, tensor_list)
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
             if hidden_states.shape[0] != 0:
                 hidden_states, residual = self.post_attention_layernorm(
                     hidden_states, residual
@@ -1349,7 +1349,7 @@ def forward_ffn_with_scattered_input(
             hidden_states += residual
             tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
             hidden_states = tensor_list[self.attn_tp_rank]
-            tp_reduce_scatter(hidden_states, tensor_list)
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
             residual = hidden_states
             if hidden_states.shape[0] != 0:
                 hidden_states = self.post_attention_layernorm(hidden_states)
@@ -1373,7 +1373,7 @@ def forward_ffn_with_scattered_input(
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
             )
-            tp_all_gather(
+            attn_tp_all_gather(
                 list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
             )
 
@@ -1475,6 +1475,7 @@ def __init__(
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
         self.dp_size = get_attention_dp_size()
```
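Beyond the renames, the pattern in `forward_ffn_with_scattered_input` is worth spelling out: each attention-TP rank splits the token dimension with `tensor_split`, keeps only its own slice, and the reduce-scatter sums slice i across ranks onto rank i, while the all-gather later reassembles the full tensor. A single-process simulation of those semantics (sizes are arbitrary; the Python comprehension stands in for the collective):

```python
import torch

attn_tp_size, num_tokens, hidden = 2, 4, 6

# Each attention-TP rank holds a partial hidden_states tensor whose values
# still need to be summed across ranks (e.g. partial FFN/attention outputs).
per_rank_hidden = [torch.randn(num_tokens, hidden) for _ in range(attn_tp_size)]

# attn_tp_reduce_scatter(output, input_list): sum input_list[i] over all ranks
# and leave the reduced slice i on rank i. Simulated here in-process.
scattered = [
    sum(h.tensor_split(attn_tp_size)[rank] for h in per_rank_hidden)
    for rank in range(attn_tp_size)
]

# attn_tp_all_gather then reassembles the fully reduced tensor on every rank.
reassembled = torch.cat(scattered, dim=0)
assert torch.allclose(reassembled, sum(per_rank_hidden))
```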

python/sglang/srt/models/llama.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -45,6 +45,7 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -420,6 +421,7 @@ def __init__(
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
```

python/sglang/srt/server_args.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -159,6 +159,7 @@ class ServerArgs:
     disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
+    enable_dp_lm_head: bool = False
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
@@ -323,6 +324,11 @@ def __post_init__(self):
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
             )
 
+        if self.enable_dp_lm_head:
+            assert (
+                self.enable_dp_attention
+            ), "Please enable dp attention when setting enable_dp_attention. "
+
         # DeepEP MoE
         self.enable_sp_layernorm = False
         if self.enable_deepep_moe:
@@ -1055,6 +1061,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
         action="store_true",
         help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
     )
+    parser.add_argument(
+        "--enable-dp-lm-head",
+        action="store_true",
+        help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.",
+    )
     parser.add_argument(
         "--enable-ep-moe",
         action="store_true",
```
