2121
2222import fastdeploy
2323from fastdeploy .model_executor .ops .gpu import moe_expert_dispatch , moe_expert_reduce
24+ from fastdeploy .model_executor .utils import set_weight_attrs
2425from fastdeploy .utils import ceil_div
2526
2627from ..quantization .quant_base import QuantMethodBase
@@ -154,6 +155,22 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
154155 default_initializer = paddle .nn .initializer .Constant (0 ),
155156 ),
156157 )
158+ for weight_name in [
159+ "up_gate_proj_weight" ,
160+ "down_proj_weight" ,
161+ "up_gate_proj_weight_scale" ,
162+ "down_proj_weight_scale" ,
163+ "up_gate_proj_super_scales" ,
164+ "down_proj_super_scales" ,
165+ "up_gate_proj_code_scale" ,
166+ "down_proj_code_scale" ,
167+ "up_gate_proj_code_zp" ,
168+ "down_proj_code_zp" ,
169+ ]:
170+ set_weight_attrs (
171+ getattr (layer , weight_name ),
172+ extra_weight_attrs ,
173+ )
157174
158175
159176class CutlassWint2FusedMoeMethod (Wint2MoeMethod ):
@@ -164,6 +181,24 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
164181 def __init__ (self , quant_config ):
165182 super ().__init__ (quant_config )
166183
184+ def process_weights_after_loading (self , layer ):
185+ if self .quant_config .is_checkpoint_bf16 :
186+ # dynamic quantize
187+ return
188+ w1_shape = layer .up_gate_proj_weight .shape
189+ up_gate_proj_weight = layer .up_gate_proj_weight .reshape (
190+ [w1_shape [0 ], w1_shape [1 ] // 16 , 16 , w1_shape [2 ] // 8 , 8 ]
191+ )
192+ up_gate_proj_weight = paddle .transpose (up_gate_proj_weight , perm = [0 , 3 , 1 , 4 , 2 ])
193+ up_gate_proj_weight = up_gate_proj_weight .reshape (w1_shape )
194+ layer .up_gate_proj_weight .data = up_gate_proj_weight
195+
196+ w2_shape = layer .down_proj_weight .shape
197+ down_proj_weight = layer .down_proj_weight .reshape ([w2_shape [0 ], w2_shape [1 ] // 16 , 16 , w2_shape [2 ] // 8 , 8 ])
198+ down_proj_weight = paddle .transpose (down_proj_weight , perm = [0 , 3 , 1 , 4 , 2 ])
199+ down_proj_weight = down_proj_weight .reshape (w2_shape )
200+ layer .down_proj_weight .data = down_proj_weight
201+
167202 def process_loaded_weights (self , layer , weights ) -> None :
168203 """
169204 process_loaded_weights
0 commit comments