Commit ba22024

[Bugfix] Fix MoE flatten_tp_size unconditionally including dp_size
Split flatten_tp_across_dp_and_pcp into separate EP and non-EP paths so DP ranks hold replicated MoE weights instead of being folded into TP.

Signed-off-by: AjAnubolu <anuboluajay@gmail.com>
1 parent 10f08de commit ba22024

1 file changed

Lines changed: 50 additions & 30 deletions

vllm/model_executor/layers/fused_moe/config.py

@@ -1015,51 +1015,64 @@ def make(
         Expert Parallelism is considered only when either `dp_size_`, `pcp_size_` or
         `tp_size_` is non trivial.
 
-        Note that PCP serves the same function as DP here.
+        Note that PCP serves the same function as TP here (PCP ranks
+        are flattened into TP), while DP ranks hold replicated MoE
+        weights when EP is disabled.
 
-        When TP = 2, DP(PCP) = 1 and EP = False, the configuration on different
-        devices:
+        When TP = 2, DP = 1, PCP = 1 and EP = False, the configuration
+        on different devices:
 
         - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
           legend : {size, rank}
         - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
         - Comment : Tensors are sharded across 2 devices.
 
-        When TP = 1, DP(PCP) = 2 and EP = False, the configuration on different
-        devices:
-
-        - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
-        - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
-        - Comment: There are 2 engine instances and the tensors are sharded
-          across 2 decvices.
-
-        When TP = 2, DP(PCP) = 2 and EP = False, the configuration on different
-        devices:
-
-        - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
-        - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
-        - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
-        - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
-        - Comment: There are 2 engine instances and the tensors are sharded
+        When TP = 1, DP = 2, PCP = 1 and EP = False, the configuration
+        on different devices:
+
+        - device 0 : TP = {1, 0} DP = {2, 0} EP = {1, 0}
+        - device 1 : TP = {1, 0} DP = {2, 1} EP = {1, 0}
+        - Comment: There are 2 engine instances. Each DP rank holds
+          a full replica of the MoE weights.
+
+        When TP = 2, DP = 2, PCP = 1 and EP = False, the configuration
+        on different devices:
+
+        - device 0: TP = {2, 0} DP = {2, 0} EP = {1, 0}
+        - device 1: TP = {2, 1} DP = {2, 0} EP = {1, 0}
+        - device 2: TP = {2, 0} DP = {2, 1} EP = {1, 0}
+        - device 3: TP = {2, 1} DP = {2, 1} EP = {1, 0}
+        - Comment: There are 2 engine instances and the tensors are
+          sharded across 2 devices within each DP group. Each DP
+          rank holds the same weight shards.
+
+        When TP = 2, DP = 1, PCP = 2 and EP = False, the configuration
+        on different devices:
+
+        - device 0: TP = {4, 0} DP = {1, 0} EP = {1, 0}
+        - device 1: TP = {4, 1} DP = {1, 0} EP = {1, 0}
+        - device 2: TP = {4, 2} DP = {1, 0} EP = {1, 0}
+        - device 3: TP = {4, 3} DP = {1, 0} EP = {1, 0}
+        - Comment: PCP is flattened into TP. Tensors are sharded
           across 4 devices.
 
-        When, TP = 2, DP(PCP) = 1 and EP = True, the configuration on different
-        devices:
+        When, TP = 2, DP = 1, PCP = 1 and EP = True, the configuration
+        on different devices:
 
         - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
         - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
         - Comment: The experts are split between the 2 devices.
 
-        When, TP = 1, DP(PCP) = 2 and EP = True, the configuration on different
-        devices:
+        When, TP = 1, DP = 2, PCP = 1 and EP = True, the configuration
+        on different devices:
 
         - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
         - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
         - Comment: There are 2 engine instances and the experts are split
           between the 2 devices.
 
-        When TP = 2, DP(PCP) = 2 and EP = True, the configuration on different
-        devices:
+        When TP = 2, DP = 2, PCP = 1 and EP = True, the configuration
+        on different devices:
 
         - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
         - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
@@ -1077,11 +1090,14 @@ def make(
         dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
         pcp_size = pcp_size_
         pcp_rank = get_pcp_group().rank_in_group if pcp_size > 1 else 0
-        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
-            tp_size_, dp_size_, dp_rank, pcp_size_, pcp_rank
-        )
 
         if not use_ep:
+            # For non-EP, only flatten PCP into TP. DP ranks hold
+            # replicated MoE weights and process different data
+            # independently, so dp_size must NOT be folded into tp_size.
+            tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
+                tp_size_, 1, 0, pcp_size_, pcp_rank
+            )
             return FusedMoEParallelConfig(
                 tp_size=tp_size,
                 tp_rank=tp_rank,
@@ -1098,8 +1114,12 @@ def make(
             )
         # DP + EP / TP + EP / DP + TP + EP
         assert use_ep
-        # In EP, each device owns a set of experts fully. There is no tensor
-        # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that.
+        # For EP, flatten DP, PCP and TP into EP size since each device
+        # owns a distinct set of experts fully (no tensor parallelism
+        # within experts).
+        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
+            tp_size_, dp_size_, dp_rank, pcp_size_, pcp_rank
+        )
         ep_size = tp_size
         ep_rank = tp_rank
         return FusedMoEParallelConfig(
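
The post-fix behaviour described in the docstring can be reproduced with a small standalone sketch. This is not the vLLM implementation: `MoEParallel`, `flatten` and `make` are hypothetical stand-ins, ranks are passed in explicitly instead of being read from process groups, and the flattened rank is assumed to be laid out with DP outermost, then PCP, then TP (the DP/PCP ordering is an assumption; the docstring examples only pin down the position of TP).

```python
from typing import NamedTuple


class MoEParallel(NamedTuple):
    # Hypothetical stand-in for FusedMoEParallelConfig; each pair is (size, rank).
    tp: tuple[int, int]
    dp: tuple[int, int]
    ep: tuple[int, int]


def flatten(tp_size, tp_rank, dp_size, dp_rank, pcp_size, pcp_rank):
    # Assumed layout: DP outermost, then PCP, then TP innermost.
    size = dp_size * pcp_size * tp_size
    rank = dp_rank * pcp_size * tp_size + pcp_rank * tp_size + tp_rank
    return size, rank


def make(tp_size, tp_rank, dp_size, dp_rank, pcp_size, pcp_rank, use_ep):
    if not use_ep:
        # Non-EP: fold only PCP into TP (dp_size=1, dp_rank=0), so every
        # DP rank keeps the same replicated weight shards.
        tp = flatten(tp_size, tp_rank, 1, 0, pcp_size, pcp_rank)
        return MoEParallel(tp=tp, dp=(dp_size, dp_rank), ep=(1, 0))
    # EP: fold DP, PCP and TP together; the result becomes the EP size/rank
    # and no tensor parallelism remains inside an expert.
    ep = flatten(tp_size, tp_rank, dp_size, dp_rank, pcp_size, pcp_rank)
    return MoEParallel(tp=(1, 0), dp=(dp_size, dp_rank), ep=ep)


if __name__ == "__main__":
    # TP = 2, DP = 2, PCP = 1, EP = False: one line per device.
    for dp_rank in range(2):
        for tp_rank in range(2):
            print(make(2, tp_rank, 2, dp_rank, 1, 0, use_ep=False))
```

With TP = 2 and DP = 2, the non-EP branch reports a TP size of 2 on every device, whereas the pre-fix code, which flattened DP unconditionally, would have reported a TP size of 4.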
