Commit cf91c0d

Alcanderian authored and Layssy committed
[refactor] slightly tidy fp8 module (sgl-project#5993)
1 parent 70b1c6b · commit cf91c0d

File tree

12 files changed: +239, -232 lines


python/sglang/srt/layers/moe/ep_moe/kernels.py

Lines changed: 2 additions & 5 deletions
@@ -12,7 +12,7 @@
 _is_cuda = is_cuda()
 if _is_cuda:
     from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
+        sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
     )
 logger = logging.getLogger(__name__)

@@ -654,10 +654,7 @@ def grouped_gemm_triton(
     if block_shape is not None:
         assert len(block_shape) == 2
         block_n, block_k = block_shape[0], block_shape[1]
-        if _is_cuda:
-            a, scale_a = sglang_per_token_group_quant_fp8(a, block_k)
-        else:
-            a, scale_a = per_token_group_quant_fp8(a, block_k)
+        a, scale_a = per_token_group_quant_fp8(a, block_k)

         assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
         assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
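
The point of this change is to select the quantization kernel once at import time and bind it to a single name, so grouped_gemm_triton no longer branches on _is_cuda at the call site. Below is a minimal, self-contained sketch of that pattern; the toy quantizer and the _has_cuda_kernel flag are illustrative stand-ins, not sglang APIs.

import torch

def _per_token_group_quant_ref(a: torch.Tensor, group_size: int):
    """Toy fallback: one scale per contiguous group of `group_size` values."""
    orig_shape = a.shape
    grouped = a.reshape(-1, group_size)
    # fp8 e4m3 max magnitude is 448; a real kernel would also cast to torch.float8_e4m3fn.
    scale = grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 448.0
    q = (grouped / scale).clamp(-448.0, 448.0)
    return q.reshape(orig_shape), scale.reshape(*orig_shape[:-1], -1)

# Bind whichever implementation is available to one generic name at import time.
_has_cuda_kernel = torch.cuda.is_available()  # stand-in for sglang's is_cuda()
if _has_cuda_kernel:
    per_token_group_quant_fp8 = _per_token_group_quant_ref  # real code: CUDA kernel
else:
    per_token_group_quant_fp8 = _per_token_group_quant_ref  # real code: Triton fallback

# Call sites stay branch-free, exactly like the tidied grouped_gemm_triton above.
a = torch.randn(4, 256)
a_q, scale_a = per_token_group_quant_fp8(a, 128)
assert scale_a.shape == (4, 2)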

python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 2 additions & 4 deletions
@@ -10,16 +10,14 @@
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy

-from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
-    is_cuda,
-    is_fp8_fnuz,
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import is_cuda, set_weight_attrs

 _is_cuda = is_cuda()
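
This file now takes is_fp8_fnuz from fp8_kernel and is_cuda from sglang.srt.utils, so the quantization utils module no longer re-exports platform checks. The helper's body is not part of this diff; the following is only a hedged sketch of what such a check plausibly looks like, assuming it detects ROCm GPUs (MI300-class gfx94x parts) whose native FP8 encoding is the fnuz variant.

import torch

def is_hip() -> bool:
    # True when PyTorch was built against ROCm rather than CUDA.
    return torch.version.hip is not None

def is_fp8_fnuz_sketch() -> bool:
    # Assumption: MI300-class ROCm GPUs (gfx94x) store FP8 as float8_e4m3fnuz,
    # while NVIDIA GPUs use float8_e4m3fn. The real helper may check differently.
    if not (is_hip() and torch.cuda.is_available()):
        return False
    return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName

# Downstream code can then pick the right dtype once:
fp8_dtype = torch.float8_e4m3fnuz if is_fp8_fnuz_sketch() else torch.float8_e4m3fn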

python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py

Lines changed: 2 additions & 1 deletion
@@ -15,11 +15,12 @@
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     normalize_e4m3fn_to_e4m3fnuz,
 )
-from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
+from sglang.srt.layers.quantization.utils import requantize_with_max_scale

 __all__ = ["CompressedTensorsW8A8Fp8"]
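
Both compressed-tensors schemes keep importing normalize_e4m3fn_to_e4m3fnuz from fp8_utils. Its body is not shown in this commit; the sketch below mirrors the widely used vLLM helper of the same name and is only an approximation of what the sglang function likely does: reinterpret the e4m3fn bits as e4m3fnuz, zero out the 0x80 pattern (negative zero in e4m3fn, NaN in e4m3fnuz), and double the scales because the same bit pattern represents half the value in fnuz.

from typing import Optional, Tuple

import torch

def normalize_e4m3fn_to_e4m3fnuz_sketch(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # Reinterpret the raw bytes, then fix the one incompatible bit pattern.
    as_int8 = weight.view(torch.int8)
    as_int8[as_int8 == -128] = 0  # 0x80: -0.0 in e4m3fn, NaN in e4m3fnuz
    weight_fnuz = as_int8.view(torch.float8_e4m3fnuz)
    # fnuz has one extra unit of exponent bias, so compensate in the scales.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight_fnuz, weight_scale, input_scale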

python/sglang/srt/layers/quantization/fp8.py

Lines changed: 106 additions & 91 deletions
@@ -42,6 +42,8 @@ def dummy_func(*args, **kwargs):
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
+    is_fp8_fnuz,
     per_token_group_quant_fp8,
     scaled_fp8_quant,
 )

@@ -71,6 +73,11 @@ def dummy_func(*args, **kwargs):
 _is_hip = is_hip()
 _is_cuda = is_cuda()

+_is_fp8_fnuz = is_fp8_fnuz()
+
+use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
+use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
+
 if _is_hip:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages

@@ -306,25 +313,21 @@ def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if _is_hip:
+            if _is_fp8_fnuz:
                 # activation_scheme: dynamic
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.weight,
                     weight_scale=layer.weight_scale_inv,
                     input_scale=None,
                 )
-                layer.weight = torch.nn.Parameter(weight, requires_grad=False)
-                layer.weight_scale_inv = torch.nn.Parameter(
-                    weight_scale, requires_grad=False
-                )
+
                 layer.input_scale = None
             else:
-                layer.weight = torch.nn.Parameter(
-                    layer.weight.data, requires_grad=False
-                )
-                layer.weight_scale_inv = torch.nn.Parameter(
-                    layer.weight_scale_inv.data, requires_grad=False
-                )
+                weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
+            layer.weight = torch.nn.Parameter(weight, requires_grad=False)
+            layer.weight_scale_inv = torch.nn.Parameter(
+                weight_scale, requires_grad=False
+            )
             return

         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)

@@ -368,7 +371,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             weight = layer.weight
             weight_scale = layer.weight_scale
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if _is_hip:
+            if _is_fp8_fnuz:
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,

@@ -482,11 +485,7 @@ def create_weights(
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

         if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = (
-                torch.uint32
-                if get_bool_env_var("SGLANG_INT4_WEIGHT")
-                else torch.float8_e4m3fn
-            )
+            params_dtype = torch.uint32 if use_hip_int4 else torch.float8_e4m3fn
         tp_size = get_tensor_model_parallel_world_size()
         if self.block_quant:
             block_n, block_k = (

@@ -511,7 +510,7 @@ def create_weights(
             )

         # WEIGHTS
-        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
+        if _is_hip and use_hip_int4:
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(

@@ -583,9 +582,7 @@ def create_weights(
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)

-        if (
-            _is_hip
-        ):  # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
+        if _is_hip:  # and use_aiter_moe: TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
                 torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),

@@ -612,7 +609,7 @@ def create_weights(
             set_weight_attrs(w13_weight_scale, extra_weight_attrs)
             set_weight_attrs(w2_weight_scale, extra_weight_attrs)

-            if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
+            if _is_hip and use_hip_int4:
                 extra_weight_attrs.update(
                     {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
                 )

@@ -644,14 +641,14 @@ def create_weights(
         layer.w2_input_scale = None

     def process_weights_after_loading(self, layer: Module) -> None:
-        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
+        if _is_hip and use_hip_int4:
             self.process_weights_hip_int4(layer)
             return

         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if _is_hip:
+            if _is_fp8_fnuz:
                 # activation_scheme: dynamic
                 w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.w13_weight,

@@ -675,20 +672,19 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 )
                 layer.w2_input_scale = None

-                if get_bool_env_var("SGLANG_AITER_MOE"):
-                    # Pre-shuffle weights
-                    layer.w13_weight.data = shuffle_weight(
-                        layer.w13_weight.contiguous(), (16, 16)
-                    )
-                    layer.w2_weight.data = shuffle_weight(
-                        layer.w2_weight.contiguous(), (16, 16)
-                    )
+            if _is_hip and use_aiter_moe:
+                # Pre-shuffle weights
+                layer.w13_weight.data = shuffle_weight(
+                    layer.w13_weight.contiguous(), (16, 16)
+                )
+                layer.w2_weight.data = shuffle_weight(
+                    layer.w2_weight.contiguous(), (16, 16)
+                )
             return

         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+            # If ROCm, fp8_dtype will be float8_e4m3fnuz (MI300x HW)
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

@@ -742,7 +738,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 )

             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if _is_hip:
+            if _is_fp8_fnuz:
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = (
                     normalize_e4m3fn_to_e4m3fnuz(

@@ -798,7 +794,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
         return

     def process_weights_hip_int4(self, layer: Module):
-        # TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
+        # TODO: and use_aiter_moe: add after triton kernel added
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(

@@ -845,7 +841,7 @@ def process_weights_hip_scale_padding(self, layer: Module):
             padding_size,  # Avoid circular import
         )

-        if get_bool_env_var("SGLANG_AITER_MOE"):
+        if use_aiter_moe:
             layer.w13_weight = torch.nn.Parameter(
                 shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,

@@ -856,7 +852,7 @@ def process_weights_hip_scale_padding(self, layer: Module):
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
-            # ROCm (SGLANG_AITER_MOE): using column-wise scaling
+            # ROCm (use_aiter_moe): using column-wise scaling
             layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
             layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
         elif get_bool_env_var("SGLANG_MOE_PADDING"):

@@ -908,59 +904,16 @@ def apply(
         )

         if _is_hip:
-            if get_bool_env_var("SGLANG_INT4_WEIGHT"):
-                # TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
-                assert not no_combine, f"{no_combine=} is not supported."
-                return ck_moe_2stages(
-                    x,
-                    layer.w13_weight,
-                    layer.w2_weight,
-                    topk_weights,
-                    topk_ids,
-                    QuantType.per_Token,
-                    layer.w13_weight_scale1,
-                    layer.w2_weight_scale1,
-                    activation=(
-                        ActivationType.Silu
-                        if activation == "silu"
-                        else ActivationType.Gelu
-                    ),
-                )
-
-            if get_bool_env_var("SGLANG_AITER_MOE"):
-                assert not no_combine, f"{no_combine=} is not supported."
-                if self.block_quant:
-                    # TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
-                    assert (
-                        activation == "silu"
-                    ), f"SGLANG_AITER_MOE: FP8 bloack_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
-                    return asm_moe(
-                        x,
-                        layer.w13_weight,
-                        layer.w2_weight,
-                        topk_weights,
-                        topk_ids,
-                        layer.w13_weight_scale_inv,
-                        layer.w2_weight_scale_inv,
-                        block_shape=tuple(self.quant_config.weight_block_size),
-                        expert_mask=None,
-                    )
-                else:
-                    return ck_moe_2stages(
-                        x,
-                        layer.w13_weight,
-                        layer.w2_weight,
-                        topk_weights,
-                        topk_ids,
-                        QuantType.per_Token,
-                        layer.w13_weight_scale1,
-                        layer.w2_weight_scale1,
-                        activation=(
-                            ActivationType.Silu
-                            if activation == "silu"
-                            else ActivationType.Gelu
-                        ),
-                    )
+            ret = self.maybe_apply_hip_fused_experts(
+                layer,
+                x,
+                topk_weights,
+                topk_ids,
+                activation,
+                no_combine,
+            )
+            if ret is not None:
+                return ret

         # Expert fusion with FP8 quantization
         return fused_experts(

@@ -987,6 +940,68 @@ def apply(
             no_combine=no_combine,
         )

+    def maybe_apply_hip_fused_experts(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str = "silu",
+        no_combine: bool = False,
+    ) -> Optional[torch.Tensor]:
+        if use_hip_int4:
+            # TODO: add triton kernel and add check use_aiter_moe
+            assert not no_combine, f"{no_combine=} is not supported."
+            return ck_moe_2stages(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                QuantType.per_Token,
+                layer.w13_weight_scale1,
+                layer.w2_weight_scale1,
+                activation=(
+                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ),
+            )
+
+        if use_aiter_moe:
+            assert not no_combine, f"{no_combine=} is not supported."
+            if self.block_quant:
+                # TODO(use_aiter_moe): FP8 block_quant only supports 'silu' for the time-being.
+                assert (
+                    activation == "silu"
+                ), f"use_aiter_moe: FP8 bloack_quant {activation=} will be supported later, unset use_aiter_moe"
+                return asm_moe(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    layer.w13_weight_scale_inv,
+                    layer.w2_weight_scale_inv,
+                    block_shape=tuple(self.quant_config.weight_block_size),
+                    expert_mask=None,
+                )
+            else:
+                return ck_moe_2stages(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    QuantType.per_Token,
+                    layer.w13_weight_scale1,
+                    layer.w2_weight_scale1,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
+                )
+        return None
+

 class Fp8KVCacheMethod(BaseKVCacheMethod):
     """

0 commit comments
