@@ -42,6 +42,8 @@ def dummy_func(*args, **kwargs):
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
+    is_fp8_fnuz,
     per_token_group_quant_fp8,
     scaled_fp8_quant,
 )
@@ -71,6 +73,11 @@ def dummy_func(*args, **kwargs):
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 
+_is_fp8_fnuz = is_fp8_fnuz()
+
+use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
+use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
+
 if _is_hip:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
@@ -306,25 +313,21 @@ def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if _is_hip:
+            if _is_fp8_fnuz:
                 # activation_scheme: dynamic
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.weight,
                     weight_scale=layer.weight_scale_inv,
                     input_scale=None,
                 )
-                layer.weight = torch.nn.Parameter(weight, requires_grad=False)
-                layer.weight_scale_inv = torch.nn.Parameter(
-                    weight_scale, requires_grad=False
-                )
+
                 layer.input_scale = None
             else:
-                layer.weight = torch.nn.Parameter(
-                    layer.weight.data, requires_grad=False
-                )
-                layer.weight_scale_inv = torch.nn.Parameter(
-                    layer.weight_scale_inv.data, requires_grad=False
-                )
+                weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
+            layer.weight = torch.nn.Parameter(weight, requires_grad=False)
+            layer.weight_scale_inv = torch.nn.Parameter(
+                weight_scale, requires_grad=False
+            )
             return
 
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
@@ -368,7 +371,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             weight = layer.weight
             weight_scale = layer.weight_scale
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if _is_hip:
+            if _is_fp8_fnuz:
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,
@@ -482,11 +485,7 @@ def create_weights(
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
         if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = (
-                torch.uint32
-                if get_bool_env_var("SGLANG_INT4_WEIGHT")
-                else torch.float8_e4m3fn
-            )
+            params_dtype = torch.uint32 if use_hip_int4 else torch.float8_e4m3fn
         tp_size = get_tensor_model_parallel_world_size()
         if self.block_quant:
             block_n, block_k = (
@@ -511,7 +510,7 @@ def create_weights(
                 )
 
         # WEIGHTS
-        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
+        if _is_hip and use_hip_int4:
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(
@@ -583,9 +582,7 @@ def create_weights(
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
 
-        if (
-            _is_hip
-        ):  # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
+        if _is_hip:  # and use_aiter_moe: TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
                 torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -612,7 +609,7 @@ def create_weights(
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
+        if _is_hip and use_hip_int4:
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
             )
@@ -644,14 +641,14 @@ def create_weights(
             layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
+        if _is_hip and use_hip_int4:
             self.process_weights_hip_int4(layer)
             return
 
         # Block quant doesn't need to process weights after loading
        if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if _is_hip:
+            if _is_fp8_fnuz:
                 # activation_scheme: dynamic
                 w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.w13_weight,
@@ -675,20 +672,19 @@ def process_weights_after_loading(self, layer: Module) -> None:
             )
             layer.w2_input_scale = None
 
-            if get_bool_env_var("SGLANG_AITER_MOE"):
-                # Pre-shuffle weights
-                layer.w13_weight.data = shuffle_weight(
-                    layer.w13_weight.contiguous(), (16, 16)
-                )
-                layer.w2_weight.data = shuffle_weight(
-                    layer.w2_weight.contiguous(), (16, 16)
-                )
+            if _is_hip and use_aiter_moe:
+                # Pre-shuffle weights
+                layer.w13_weight.data = shuffle_weight(
+                    layer.w13_weight.contiguous(), (16, 16)
+                )
+                layer.w2_weight.data = shuffle_weight(
+                    layer.w2_weight.contiguous(), (16, 16)
+                )
             return
 
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+            # If ROCm, fp8_dtype will be float8_e4m3fnuz (MI300x HW)
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
@@ -742,7 +738,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 )
 
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if _is_hip:
+            if _is_fp8_fnuz:
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = (
                     normalize_e4m3fn_to_e4m3fnuz(
@@ -798,7 +794,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
         return
 
     def process_weights_hip_int4(self, layer: Module):
-        # TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
+        # TODO: and use_aiter_moe: add after triton kernel added
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(
@@ -845,7 +841,7 @@ def process_weights_hip_scale_padding(self, layer: Module):
             padding_size,  # Avoid circular import
         )
 
-        if get_bool_env_var("SGLANG_AITER_MOE"):
+        if use_aiter_moe:
             layer.w13_weight = torch.nn.Parameter(
                 shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
@@ -856,7 +852,7 @@ def process_weights_hip_scale_padding(self, layer: Module):
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
-            # ROCm (SGLANG_AITER_MOE): using column-wise scaling
+            # ROCm (use_aiter_moe): using column-wise scaling
             layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
             layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
         elif get_bool_env_var("SGLANG_MOE_PADDING"):
@@ -908,59 +904,16 @@ def apply(
         )
 
         if _is_hip:
-            if get_bool_env_var("SGLANG_INT4_WEIGHT"):
-                # TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
-                assert not no_combine, f"{no_combine=} is not supported."
-                return ck_moe_2stages(
-                    x,
-                    layer.w13_weight,
-                    layer.w2_weight,
-                    topk_weights,
-                    topk_ids,
-                    QuantType.per_Token,
-                    layer.w13_weight_scale1,
-                    layer.w2_weight_scale1,
-                    activation=(
-                        ActivationType.Silu
-                        if activation == "silu"
-                        else ActivationType.Gelu
-                    ),
-                )
-
-            if get_bool_env_var("SGLANG_AITER_MOE"):
-                assert not no_combine, f"{no_combine=} is not supported."
-                if self.block_quant:
-                    # TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
-                    assert (
-                        activation == "silu"
-                    ), f"SGLANG_AITER_MOE: FP8 block_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
-                    return asm_moe(
-                        x,
-                        layer.w13_weight,
-                        layer.w2_weight,
-                        topk_weights,
-                        topk_ids,
-                        layer.w13_weight_scale_inv,
-                        layer.w2_weight_scale_inv,
-                        block_shape=tuple(self.quant_config.weight_block_size),
-                        expert_mask=None,
-                    )
-                else:
-                    return ck_moe_2stages(
-                        x,
-                        layer.w13_weight,
-                        layer.w2_weight,
-                        topk_weights,
-                        topk_ids,
-                        QuantType.per_Token,
-                        layer.w13_weight_scale1,
-                        layer.w2_weight_scale1,
-                        activation=(
-                            ActivationType.Silu
-                            if activation == "silu"
-                            else ActivationType.Gelu
-                        ),
-                    )
+            ret = self.maybe_apply_hip_fused_experts(
+                layer,
+                x,
+                topk_weights,
+                topk_ids,
+                activation,
+                no_combine,
+            )
+            if ret is not None:
+                return ret
 
         # Expert fusion with FP8 quantization
         return fused_experts(
@@ -987,6 +940,68 @@ def apply(
             no_combine=no_combine,
         )
 
+    def maybe_apply_hip_fused_experts(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str = "silu",
+        no_combine: bool = False,
+    ) -> Optional[torch.Tensor]:
+        if use_hip_int4:
+            # TODO: add triton kernel and add check use_aiter_moe
+            assert not no_combine, f"{no_combine=} is not supported."
+            return ck_moe_2stages(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                QuantType.per_Token,
+                layer.w13_weight_scale1,
+                layer.w2_weight_scale1,
+                activation=(
+                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ),
+            )
+
+        if use_aiter_moe:
+            assert not no_combine, f"{no_combine=} is not supported."
+            if self.block_quant:
+                # TODO(use_aiter_moe): FP8 block_quant only supports 'silu' for the time-being.
+                assert (
+                    activation == "silu"
+                ), f"use_aiter_moe: FP8 block_quant {activation=} will be supported later, unset use_aiter_moe"
+                return asm_moe(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    layer.w13_weight_scale_inv,
+                    layer.w2_weight_scale_inv,
+                    block_shape=tuple(self.quant_config.weight_block_size),
+                    expert_mask=None,
+                )
+            else:
+                return ck_moe_2stages(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    QuantType.per_Token,
+                    layer.w13_weight_scale1,
+                    layer.w2_weight_scale1,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
+                )
+        return None
+
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
     """