
Commit bfb8ef7

laixinn and sleepcoo authored and committed
[qwen3] support qwen3 ep moe (sgl-project#5917)
Co-authored-by: sleepcoo <[email protected]>
1 parent e462c2c · commit bfb8ef7

File tree

2 files changed: +16 −6 lines changed

python/sglang/srt/models/qwen2_moe.py

Lines changed: 8 additions & 3 deletions
@@ -36,6 +36,7 @@
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -45,6 +46,7 @@
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, make_layers
@@ -108,12 +110,13 @@ def __init__(
                 f"the number of experts {config.num_experts}."
             )
 
-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+
+        self.experts = MoEImpl(
             num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
@@ -427,7 +430,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
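
Both files apply the same pattern: instead of hard-coding FusedMoE, the MoE implementation is selected once from the server arguments (the "enable_ep_moe" entry of global_server_args_dict, presumably set by an --enable-ep-moe launch flag) and then constructed with the same keyword arguments; the reduce_results=False argument is also dropped from the constructor call. Below is a minimal, self-contained Python sketch of that dispatch pattern only; the dataclasses, the args dict, and the numeric sizes are stand-ins, not sglang's real EPMoE/FusedMoE layers or Qwen3 config values.

# Stand-ins for sglang's EPMoE / FusedMoE layers and server-args dict;
# this only illustrates the dispatch pattern used in the diff above.
from dataclasses import dataclass


@dataclass
class FusedMoE:  # tensor-parallel MoE (default path)
    num_experts: int
    top_k: int
    hidden_size: int
    intermediate_size: int
    renormalize: bool


@dataclass
class EPMoE(FusedMoE):  # expert-parallel MoE (taken when the flag is set)
    pass


# Stand-in for global_server_args_dict.
global_server_args_dict = {"enable_ep_moe": True}

# Pick the implementation once, then construct it with identical kwargs,
# mirroring the patched __init__ above (placeholder sizes).
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
experts = MoEImpl(
    num_experts=128,
    top_k=8,
    hidden_size=4096,
    intermediate_size=1536,
    renormalize=True,
)
print(type(experts).__name__)  # EPMoE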

python/sglang/srt/models/qwen3_moe.py

Lines changed: 8 additions & 3 deletions
@@ -40,6 +40,7 @@
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -48,6 +49,7 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
@@ -73,12 +75,13 @@ def __init__(
                 f"the number of experts {config.num_experts}."
             )
 
-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+
+        self.experts = MoEImpl(
             num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
@@ -356,7 +359,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
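
In load_weights, both models likewise ask the selected implementation for its expert-parameter mapping via make_expert_params_mapping. The sketch below is a rough, hypothetical illustration of what such a mapping could produce; only the ckpt_*_name arguments are taken from the diff, while the function body, the fused parameter names (w13_weight / w2_weight), and the tuple layout are assumptions and may differ from sglang's actual helper.

# Hypothetical sketch of an expert-params mapping; sglang's real
# MoEImpl.make_expert_params_mapping may differ in names and tuple layout.
from typing import List, Tuple


def make_expert_params_mapping(
    ckpt_gate_proj_name: str,
    ckpt_down_proj_name: str,
    ckpt_up_proj_name: str,
    num_experts: int,
) -> List[Tuple[str, str, int, str]]:
    """Map per-expert checkpoint weight names to fused module parameters.

    Each entry is (param_name, checkpoint_weight_name, expert_id, shard_id):
    in this sketch, gate/up projections load into a fused w13 weight and
    down projections into w2.
    """
    mapping = []
    for expert_id in range(num_experts):
        for ckpt_name, shard_id in [
            (ckpt_gate_proj_name, "w1"),
            (ckpt_up_proj_name, "w3"),
            (ckpt_down_proj_name, "w2"),
        ]:
            param_name = (
                "experts.w13_weight" if shard_id in ("w1", "w3") else "experts.w2_weight"
            )
            mapping.append(
                (param_name, f"experts.{expert_id}.{ckpt_name}.weight", expert_id, shard_id)
            )
    return mapping


# Example: the first few entries for a 2-expert model.
for entry in make_expert_params_mapping("gate_proj", "down_proj", "up_proj", num_experts=2)[:3]:
    print(entry)

During weight loading, each per-expert checkpoint tensor would then be routed to the corresponding slice of the fused expert weights using the expert_id and shard_id from such a mapping.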
