     weight_quantize_xpu,
     xpu_moe_layer,
 )
-from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs
+from fastdeploy.model_executor.utils import (
+    TensorTracker,
+    default_weight_loader,
+    free_tensor,
+    set_weight_attrs,
+)


 class XPUMoEMethod(MoEMethodBase):
@@ -62,15 +67,17 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
         """
         create weight process.
         """
-        if layer.fd_config.load_config.load_choices == "default_v1" and self.moe_quant_type in ["w16a16"]:
+        if layer.fd_config.load_config.load_choices == "default_v1" and self.moe_quant_type in [
+            "w16a16",
+            "weight_only_int8",
+            "weight_only_int4",
+        ]:
             self.up_gate_proj_weight_shape = [
                 layer.num_local_experts,
                 layer.moe_intermediate_size * 2,
                 layer.hidden_size,
             ]
             self.down_proj_weight_shape = [layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size]
-            extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
-
             layer.up_gate_proj_weight = layer.create_parameter(
                 shape=self.up_gate_proj_weight_shape,
                 dtype=layer.weight_dtype,
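For reference, a minimal sketch of the unquantized bf16 parameter shapes this branch allocates; the sizes below (8 local experts, intermediate size 3584, hidden size 4096) are illustrative assumptions, not values taken from the patch:

    # Illustrative sizes only; the patch reads them from the layer config.
    num_local_experts, moe_intermediate_size, hidden_size = 8, 3584, 4096
    up_gate_proj_weight_shape = [num_local_experts, moe_intermediate_size * 2, hidden_size]  # [8, 7168, 4096]
    down_proj_weight_shape = [num_local_experts, hidden_size, moe_intermediate_size]         # [8, 4096, 3584]
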
@@ -86,18 +93,21 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
             set_weight_attrs(
                 layer.up_gate_proj_weight,
                 {
+                    "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
                     "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
                     "weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
+                    "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=False),
                 },
             )
             set_weight_attrs(
                 layer.down_proj_weight,
                 {
+                    "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
                     "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
                     "weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
+                    "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=True),
                 },
             )
-
             if layer.with_bias:
                 layer.up_gate_proj_bias = layer.create_parameter(
                     shape=[layer.num_experts, layer.moe_intermediate_size * 2],
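The SHARD_ID_TO_SHARDED_DIM entry added above reads as: within each per-expert 2-D matrix, the gate/up shards are split along dim 0 (the 2 * moe_intermediate_size output dim) and the down shard along dim 1; the mapping previously lived in extra_weight_attrs and now travels with each parameter. The tensor_track attribute is what later lets the new process_weights_after_loading (added further down) detect when a logical weight has been fully written by the sharded loader. A minimal sketch of what such a tracker could look like follows; the real TensorTracker lives in fastdeploy.model_executor.utils and its internals are not shown in this diff, so treat the implementation and the mark() signature as assumptions:

    import paddle

    class TensorTrackerSketch:
        # Hypothetical stand-in for fastdeploy.model_executor.utils.TensorTracker.
        def __init__(self, shape, output_dim):
            self.output_dim = output_dim  # whether the sharded dim is the output dim (assumption)
            self.copied = paddle.zeros(shape, dtype="bool")

        def mark(self, index):
            # Assumed hook: the weight loader records each region it has filled,
            # e.g. tracker.mark((expert_id, slice(start, stop))).
            self.copied[index] = True

        def is_fully_copied(self):
            # True once every element has been written; this is the condition
            # process_weights_after_loading checks before quantizing a weight.
            return bool(self.copied.all().item())

Tracking at tensor granularity means quantization can fire per logical weight as soon as that weight is complete, instead of waiting for the entire state dict.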
@@ -128,6 +138,15 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
                         "model_format": extra_weight_attrs.get("model_format", ""),
                     },
                 )
+            if self.moe_quant_type in ["weight_only_int8", "weight_only_int4"]:
+                self.up_gate_proj_scale_shape = [
+                    layer.num_local_experts,
+                    layer.moe_intermediate_size * 2,
+                ]
+                self.down_proj_scale_shape = [
+                    layer.num_local_experts,
+                    layer.hidden_size,
+                ]

         else:
             self.up_gate_proj_weight_shape = [
@@ -531,6 +550,87 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
         quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
         getattr(layer, scale_name).set_value(quanted_weight_scale)

+    def process_weights_after_loading(self, layer):
+        """Quantize bf16 MoE weights to weight-only int8/int4 once they are fully loaded."""
+        if not self.quant_config.is_checkpoint_bf16:
+            return
+        weight_id_map = {"gate_up": 0, "down": 1}
+        if (
+            hasattr(layer.up_gate_proj_weight, "tensor_track")
+            and layer.up_gate_proj_weight.tensor_track is not None
+            and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
+        ):
+            weight_type = "gate_up"
+        else:
+            weight_type = "down"
+
+        # 1. init shape and type
+        # weight
+        weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
+        unquantized_weight_name = weight_name.replace("quant_weight", "weight")
+        if weight_type == "gate_up":
+            weight_shape = [
+                layer.num_local_experts,
+                layer.moe_intermediate_size * 2,
+                layer.hidden_size,
+            ]
+        else:
+            weight_shape = [
+                layer.num_local_experts,
+                layer.hidden_size,
+                layer.moe_intermediate_size,
+            ]
+        weight_dtype = "int8"
+        # scale
+        scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
+        scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
+        if self.moe_quant_type in ["weight_only_int4"]:
+            weight_shape[-1] //= 2
+        scale_dtype = "float32"
+
+        # 2. create tmp tensor
+
+        # weight = paddle.empty(weight_shape, dtype=weight_dtype)
+        # scale = paddle.empty(scale_shape, dtype=scale_dtype)
+
+        # 3. quantize weight
+        weight_list = []
+        weight_scale_list = []
+        for expert_id in range(layer.num_local_experts):
+            quant_weight, scale = weight_quantize_xpu(
+                getattr(layer, unquantized_weight_name)[expert_id].transpose([1, 0]), self.moe_quant_type, -1, -1
+            )
+            weight_list.append(quant_weight.transpose([1, 0]))
+            weight_scale_list.append(scale)
+        quanted_weight = paddle.stack(weight_list, axis=0)
+        quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
+
+        free_tensor(getattr(layer, unquantized_weight_name))
+
+        # create weight
+        setattr(
+            layer,
+            weight_name,
+            layer.create_parameter(
+                shape=weight_shape,
+                dtype=weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        # create scale
+        setattr(
+            layer,
+            scale_name,
+            layer.create_parameter(
+                shape=scale_shape,
+                dtype=scale_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+
+        getattr(layer, weight_name).set_value(quanted_weight)
+        getattr(layer, scale_name).set_value(quanted_weight_scale)
+

 class XPUW4A8MoEMethod(XPUMoEMethod):
     """
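As a shape check on the new process_weights_after_loading path: the quantized weight keeps the bf16 layout but stores int8, the scales hold one float32 per output channel per expert, and for weight_only_int4 the last weight dim is halved (two int4 values presumably packed into each stored int8 byte, matching the weight_shape[-1] //= 2 above). A small sketch with assumed sizes, illustrative only:

    # Illustrative sizes only: 8 local experts, intermediate 3584, hidden 4096.
    E, I, H = 8, 3584, 4096

    def quantized_moe_shapes(moe_quant_type):
        gate_up_w, down_w = [E, 2 * I, H], [E, H, I]    # int8 storage
        gate_up_s, down_s = [E, 2 * I], [E, H]          # float32 scales, one per output channel
        if moe_quant_type == "weight_only_int4":
            gate_up_w[-1] //= 2                         # two int4 values per stored int8 byte
            down_w[-1] //= 2
        return gate_up_w, gate_up_s, down_w, down_s

    print(quantized_moe_shapes("weight_only_int8"))  # ([8, 7168, 4096], [8, 7168], [8, 4096, 3584], [8, 4096])
    print(quantized_moe_shapes("weight_only_int4"))  # ([8, 7168, 2048], [8, 7168], [8, 4096, 1792], [8, 4096])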