Commit 9419453

[AMD] Add MoE weights and scales padding (#18684)
1 parent f97c09d commit 9419453

8 files changed: 131 additions & 36 deletions

python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

Lines changed: 2 additions & 2 deletions
@@ -6,14 +6,14 @@
 from __future__ import annotations
 
 import functools
-import os
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import torch.nn.functional as F
 import triton.language as tl
 
 from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.utils import get_moe_padding_size
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -75,7 +75,7 @@
 # Fallback: vllm not available, will use native PyTorch implementations
 _has_vllm_ops = False
 
-padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
+padding_size = get_moe_padding_size(_use_aiter)
 
 
 @register_custom_op(mutates_args=["hidden_states"])
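For reference, the new helper behaves identically to the deleted env-var expression whenever aiter is off; a minimal standalone check (the helper body mirrors the one added in utils.py below):

import os

def get_moe_padding_size(is_aiter_moe: bool) -> int:
    # Mirrors the helper added in python/sglang/srt/layers/moe/utils.py:
    # aiter always pads to 128; otherwise the old env-var switch applies.
    if is_aiter_moe:
        return 128
    return 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0

# Old module-level expression, kept for comparison:
old_padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
assert get_moe_padding_size(False) == old_padding_size  # unchanged when aiter is off
assert get_moe_padding_size(True) == 128                # always padded under aiter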

python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import functools
-import os
 from collections import OrderedDict
 from typing import Any, Dict, List, Optional
 
@@ -10,6 +9,7 @@
 import triton.language as tl
 
 from sglang.srt.batch_invariant_ops import is_batch_invariant_mode_enabled
+from sglang.srt.layers.moe.utils import get_moe_padding_size
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     scaled_fp8_quant,
@@ -49,7 +49,7 @@
 elif _is_hip:
     pass
 
-padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
+padding_size = get_moe_padding_size(_use_aiter)
 
 
 def support_tensor_descriptor():

python/sglang/srt/layers/moe/fused_moe_triton/layer.py

Lines changed: 14 additions & 2 deletions
@@ -440,8 +440,14 @@ def _load_w13(
         # Use narrow_padded_param_and_loaded_weight for:
         # 1. CPU (always)
         # 2. GPU with flashinfer_trtllm padding (when intermediate_size is padded to 128)
+        # 3. GPU with Aiter padding
         # This handles the case where the loaded weights are smaller than the padded expert_data
-        use_padded_loading = _is_cpu or self.use_flashinfer_trtllm_moe
+        aiter_padded = (
+            _use_aiter
+            and hasattr(self, "w2_weight")
+            and getattr(self.w2_weight, "weight_padded", False)
+        )
+        use_padded_loading = _is_cpu or self.use_flashinfer_trtllm_moe or aiter_padded
         if use_padded_loading:
             expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
                 expert_data,
@@ -514,8 +520,14 @@ def _load_w2(
         # Use narrow_padded_param_and_loaded_weight for:
         # 1. CPU (always)
         # 2. GPU with flashinfer_trtllm padding (when intermediate_size is padded to 128)
+        # 3. GPU with Aiter padding
        # This handles the case where the loaded weights are smaller than the padded expert_data
-        use_padded_loading = _is_cpu or self.use_flashinfer_trtllm_moe
+        aiter_padded = (
+            _use_aiter
+            and hasattr(self, "w2_weight")
+            and getattr(self.w2_weight, "weight_padded", False)
+        )
+        use_padded_loading = _is_cpu or self.use_flashinfer_trtllm_moe or aiter_padded
         if use_padded_loading:
             expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
                 expert_data,
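Conceptually, the padded-loading path narrows the preallocated (padded) parameter down to the checkpoint shard's extent before copying, so the trailing padding stays zero. A minimal standalone sketch of the idea (narrow_to_loaded is a hypothetical stand-in, not the real narrow_padded_param_and_loaded_weight signature):

import torch

def narrow_to_loaded(expert_data: torch.Tensor, loaded: torch.Tensor, dim: int):
    # Hypothetical helper: trim the padded destination along `dim` so shapes
    # match the (smaller) checkpoint weight; the padding rows are untouched.
    return expert_data.narrow(dim, 0, loaded.size(dim)), loaded

expert_data = torch.zeros(1536, 4096)    # w2 down dim padded 1472 -> 1536
loaded_weight = torch.randn(1472, 4096)  # unpadded checkpoint shard
dst, src = narrow_to_loaded(expert_data, loaded_weight, dim=0)
dst.copy_(src)                           # rows 1472..1535 remain zero padding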

python/sglang/srt/layers/moe/utils.py

Lines changed: 30 additions & 0 deletions
@@ -1,10 +1,13 @@
 from __future__ import annotations
 
 import logging
+import os
 from contextlib import contextmanager
 from enum import Enum, IntEnum
 from typing import TYPE_CHECKING, Optional
 
+import torch
+
 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
@@ -341,3 +344,30 @@ class RoutingMethodType(IntEnum):
     TopK = (5,)
     # Unspecified
     Unspecified = 6
+
+
+def get_moe_padding_size(is_aiter_moe):
+    if is_aiter_moe:
+        return 128
+    else:
+        return 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
+
+
+def get_moe_weight_sizes(inter_dim, is_concat, is_packed, is_aiter_moe):
+    w13_up_dim = 2 * inter_dim if is_concat else inter_dim
+    w2_down_dim = inter_dim // 2 if is_packed else inter_dim
+
+    if is_aiter_moe:
+        padding_size = get_moe_padding_size(True)
+        align_aiter = lambda n: ((n + padding_size - 1) // padding_size) * padding_size
+        is_padded = (w2_down_dim % padding_size) > 0
+        if is_padded:
+            w2_down_dim = align_aiter(w2_down_dim)
+        # up proj + gate fusion: 2x
+        if is_concat:
+            w13_up_dim = w2_down_dim * 2
+        # packed
+        if hasattr(torch, "float4_e2m1fn_x2") and is_packed:
+            w13_up_dim *= 2
+
+    return (w13_up_dim, w2_down_dim, False if not is_aiter_moe else is_padded)
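A quick worked example of the rounding above, assuming the functions just defined (padding_size is 128 on the aiter path):

w13, w2, padded = get_moe_weight_sizes(
    1472, is_concat=True, is_packed=False, is_aiter_moe=True
)
# 1472 % 128 == 64, so w2 is aligned up: ((1472 + 127) // 128) * 128 == 1536
assert padded and w2 == 1536
# fused gate+up doubles the *padded* down dim:
assert w13 == 2 * w2 == 3072
# without aiter the sizes pass through unchanged and no padding is reported:
assert get_moe_weight_sizes(1472, True, False, False) == (2944, 1472, False)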

python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8_moe.py

Lines changed: 18 additions & 4 deletions
@@ -12,7 +12,10 @@
     FlashInferTrtllmFp8MoeQuantInfo,
 )
 from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
-from sglang.srt.layers.moe.utils import get_moe_runner_backend
+from sglang.srt.layers.moe.utils import (
+    get_moe_runner_backend,
+    get_moe_weight_sizes,
+)
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsMoEScheme,
 )
@@ -120,11 +123,22 @@ def create_weights(
                 f"weight quantization block_k = {block_k}."
             )
 
+        w13_up_dim, w2_down_dim, weight_padded = get_moe_weight_sizes(
+            intermediate_size_per_partition,
+            is_aiter_moe=True,
+            is_concat=True,
+            is_packed=False,
+        )
+
+        extra_weight_attrs.update(
+            {"weight_padded": weight_padded},
+        )
+
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                2 * intermediate_size_per_partition,
+                w13_up_dim,
                 hidden_size,
                 dtype=params_dtype,
             ),
@@ -137,7 +151,7 @@
             torch.empty(
                 num_experts,
                 hidden_size,
-                intermediate_size_per_partition,
+                w2_down_dim,
                 dtype=params_dtype,
             ),
             requires_grad=False,
@@ -161,7 +175,7 @@
         w13_weight_scale = torch.nn.Parameter(
             torch.ones(
                 num_experts,
-                2 * intermediate_size_per_partition,
+                w13_up_dim,
                 1,
                 dtype=torch.float32,
             ),
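The weight_padded flag travels with the parameter via extra_weight_attrs / set_weight_attrs, which is what the loader gate in layer.py reads back. A rough sketch of that handshake, using a simplified stand-in for set_weight_attrs rather than the real implementation:

import torch

def set_weight_attrs(param: torch.nn.Parameter, attrs: dict) -> None:
    # Simplified stand-in: attach metadata to a Parameter as attributes.
    for k, v in attrs.items():
        setattr(param, k, v)

w2_weight = torch.nn.Parameter(torch.empty(8, 4096, 1536), requires_grad=False)
set_weight_attrs(w2_weight, {"weight_padded": True})

# Loader side (as in _load_w13/_load_w2): fall back to False if the attr is absent.
aiter_padded = getattr(w2_weight, "weight_padded", False)
assert aiter_padded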

python/sglang/srt/layers/quantization/fp8.py

Lines changed: 37 additions & 20 deletions
@@ -26,7 +26,12 @@
     FlashInferTrtllmFp8MoeQuantInfo,
 )
 from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
-from sglang.srt.layers.moe.utils import RoutingMethodType, get_moe_runner_backend
+from sglang.srt.layers.moe.utils import (
+    RoutingMethodType,
+    get_moe_padding_size,
+    get_moe_runner_backend,
+    get_moe_weight_sizes,
+)
 from sglang.srt.layers.parameter import (
     BlockQuantScaleParameter,
     ModelWeightParameter,
@@ -778,27 +783,38 @@ def create_weights(
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.uint32 if _use_hip_int4 else torch.float8_e4m3fn
         tp_size = get_tensor_model_parallel_world_size()
+
+        w13_up_dim, w2_up_dim, weight_padded = get_moe_weight_sizes(
+            intermediate_size_per_partition,
+            is_aiter_moe=True,
+            is_concat=True,
+            is_packed=False,
+        )
+
         if self.block_quant:
             block_n, block_k = (
                 self.quant_config.weight_block_size[0],
                 self.quant_config.weight_block_size[1],
             )
-            # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
-            # Required by column parallel or enabling merged weights
-            if intermediate_size_per_partition % block_n != 0:
-                raise ValueError(
-                    f"The output_size of gate's and up's weight = "
-                    f"{intermediate_size_per_partition} is not divisible by "
-                    f"weight quantization block_n = {block_n}."
-                )
-            if tp_size > 1:
-                # Required by row parallel
-                if intermediate_size_per_partition % block_k != 0:
-                    raise ValueError(
-                        f"The input_size of down's weight = "
-                        f"{intermediate_size_per_partition} is not divisible by "
-                        f"weight quantization block_k = {block_k}."
-                    )
+
+            padding_size = get_moe_padding_size(_use_aiter)
+            if not (_use_aiter and padding_size == block_n == block_k):
+                # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
+                # Required by column parallel or enabling merged weights
+                if intermediate_size_per_partition % block_n != 0:
+                    raise ValueError(
+                        f"The output_size of gate's and up's weight = "
+                        f"{intermediate_size_per_partition} is not divisible by "
+                        f"weight quantization block_n = {block_n}."
+                    )
+                if tp_size > 1:
+                    # Required by row parallel
+                    if intermediate_size_per_partition % block_k != 0:
+                        raise ValueError(
+                            f"The input_size of down's weight = "
+                            f"{intermediate_size_per_partition} is not divisible by "
+                            f"weight quantization block_k = {block_k}."
+                        )
 
         # WEIGHTS
         if _is_hip and _use_hip_int4:
@@ -825,7 +841,7 @@
         w13_weight = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                2 * intermediate_size_per_partition,
+                w13_up_dim,
                 hidden_size,
                 dtype=params_dtype,
             ),
@@ -835,12 +851,16 @@
             torch.empty(
                 num_experts,
                 hidden_size,
-                intermediate_size_per_partition,
+                w2_up_dim,
                 dtype=params_dtype,
             ),
             requires_grad=False,
         )
 
+        extra_weight_attrs.update(
+            {"weight_padded": weight_padded},
+        )
+
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
@@ -1401,10 +1421,7 @@ def process_weights_hip_int4(self, layer: Module):
             layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
 
     def process_weights_hip_scale_padding(self, layer: Module):
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
-            padding_size,  # Avoid circular import
-        )
-
+        padding_size = get_moe_padding_size(_use_aiter)
         if _use_aiter:
             layer.w13_weight = torch.nn.Parameter(
                 shuffle_weight(layer.w13_weight.data, (16, 16)),
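Reading padding_size from the shared helper also removes the deferred import that existed only to dodge a circular dependency. Relaxing the block_n/block_k divisibility check is safe here because aiter pads the weight dims to the same 128 granularity; a quick sanity sketch with illustrative values:

from sglang.srt.layers.moe.utils import get_moe_padding_size, get_moe_weight_sizes

block_n = block_k = 128
padding_size = get_moe_padding_size(True)      # 128 under aiter
# The guard skips the divisibility checks only in this exact configuration:
skip_checks = padding_size == block_n == block_k
# and the padded dims are multiples of 128 by construction, e.g.:
w13_up_dim, w2_up_dim, _ = get_moe_weight_sizes(
    1472, is_concat=True, is_packed=False, is_aiter_moe=True
)
assert skip_checks and w2_up_dim % block_n == 0 and w13_up_dim % block_n == 0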

python/sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4_moe.py

Lines changed: 23 additions & 5 deletions
@@ -8,6 +8,7 @@
 import torch
 
 from sglang.srt.layers.moe import MoeRunnerConfig
+from sglang.srt.layers.moe.utils import get_moe_weight_sizes
 from sglang.srt.layers.quantization.quark.schemes import QuarkMoEScheme
 from sglang.srt.utils import (
     get_bool_env_var,
@@ -73,10 +74,20 @@ def create_weights(
 
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
+        w13_up_dim, w2_down_dim, weight_padded = get_moe_weight_sizes(
+            intermediate_size_per_partition,
+            is_aiter_moe=True,
+            is_concat=True,
+            is_packed=True,
+        )
+
         # Add the quantization method used (per tensor/grouped/channel)
         # to ensure the weight scales are loaded in properly
         extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
+            {
+                "quant_method": FusedMoeWeightScaleSupported.BLOCK.value,
+                "weight_padded": weight_padded,
+            },
         )
 
         params_dtype = torch.uint8
@@ -85,7 +96,7 @@
         w13_weight = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                2 * intermediate_size_per_partition,
+                w13_up_dim,
                 hidden_size // 2,
                 dtype=params_dtype,
             ),
@@ -99,7 +110,7 @@
             torch.empty(
                 num_experts,
                 hidden_size,
-                intermediate_size_per_partition // 2,
+                w2_down_dim,
                 dtype=params_dtype,
             ),
             requires_grad=False,
@@ -112,17 +123,24 @@
         w13_weight_scale = torch.nn.Parameter(
             torch.ones(
                 num_experts,
-                2 * intermediate_size_per_partition,
+                w13_up_dim,
                 hidden_size // OCP_MX_BLOCK_SIZE,
                 dtype=params_dtype,
             ),
             requires_grad=False,
         )
+
+        W2_SCALE_DIVIDEND = w2_down_dim * 2
+        W2_SCALE_DIVISOR = intermediate_size_per_partition
+        scaling_up = lambda dividend, divisor: (dividend * W2_SCALE_DIVIDEND) // (
+            divisor * W2_SCALE_DIVISOR
+        )
+
         w2_weight_scale = torch.nn.Parameter(
             torch.ones(
                 num_experts,
                 hidden_size,
-                intermediate_size_per_partition // OCP_MX_BLOCK_SIZE,
+                scaling_up(intermediate_size_per_partition, OCP_MX_BLOCK_SIZE),
                 dtype=params_dtype,
             ),
             requires_grad=False,
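The scaling_up lambda looks opaque but reduces algebraically: the intermediate_size_per_partition factors cancel, leaving 2 * w2_down_dim // OCP_MX_BLOCK_SIZE, i.e. the scale columns are resized to match the padded (unpacked) down dimension. A worked check with illustrative values:

OCP_MX_BLOCK_SIZE = 32
intermediate_size_per_partition = 1472   # illustrative per-rank size
w2_down_dim = 768                        # packed 1472 // 2 == 736, padded to 768

W2_SCALE_DIVIDEND = w2_down_dim * 2
W2_SCALE_DIVISOR = intermediate_size_per_partition
scaling_up = lambda dividend, divisor: (dividend * W2_SCALE_DIVIDEND) // (
    divisor * W2_SCALE_DIVISOR
)

cols = scaling_up(intermediate_size_per_partition, OCP_MX_BLOCK_SIZE)
assert cols == (2 * w2_down_dim) // OCP_MX_BLOCK_SIZE == 48  # was 1472 // 32 == 46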

python/sglang/srt/model_executor/model_runner.py

Lines changed: 5 additions & 1 deletion
@@ -162,6 +162,7 @@
     empty_context,
     enable_show_time_cost,
     get_available_gpu_memory,
+    get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
     is_hip,
@@ -198,6 +199,7 @@
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu_arm64 = is_host_cpu_arm64()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_npu:
     from sglang.srt.hardware_backend.npu.utils import init_npu_backend
@@ -799,7 +801,9 @@ def check_quantized_moe_compatibility(self):
                 f"moe_intermediate_size {moe_intermediate_size} must be divisible by moe_tp_size ({moe_tp_size}) which is tp_size ({self.tp_size}) divided by moe_ep_size ({self.moe_ep_size})."
             )
 
-        if (moe_intermediate_size // moe_tp_size) % weight_block_size_n != 0:
+        if (
+            moe_intermediate_size // moe_tp_size
+        ) % weight_block_size_n != 0 and not _use_aiter:
             raise ValueError(
                 f"For quantized MoE models, please make sure ({moe_intermediate_size=} / {moe_tp_size=}) % {weight_block_size_n=} == 0 "
                 f"where moe_tp_size is equal to tp_size ({self.tp_size}) divided by ep_size ({self.moe_ep_size}). "
